2021.06.16.
Presenter: Sangwoo Mo
1
(ICLR 2021 spotlight)
• Goal: Find local minima 𝒘 that generalize well to test samples
• Method: Minimize the sharpness-aware (SAM) objective instead of the ERM objective
• “Flat minima” improve generalization (e.g., [1])
• With some approximation, the SAM gradient is computed via a 2-step gradient procedure,
where $\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})$ is a function of the gradient at the original weights, $\nabla_{\boldsymbol{w}} L_S(\boldsymbol{w})$
• Results: Used in recent SOTA methods (e.g., NFNet [2], {ViT, MLP-Mixer}-SAM [3])
• SAM consistently improves classification tasks, particularly under label noise
• With SAM, ViT outperforms a ResNet of the same size (MLP-Mixer is on par) (reported in [3])
1-minute summary
2
[1] Fantastic Generalization Measures and Where to Find Them. ICLR 2020.
[2] High-Performance Large-Scale Image Recognition Without Normalization. ICML 2021.
[3] When Vision Transformers Outperform ResNets without Pretraining or Strong Data Augmentations. Under review.
• Sharpness-aware minimization (ICLR 2021)
• + Brief introduction of other “flat minima” methods
• When Vision Transformers Outperform ResNets (under review)
• SAM is substantially effective for ViT and MLP-Mixers!
Outline
3
• Given a training dataset 𝑆 drawn from the data distribution 𝒟, a generalization bound aims to
guarantee the population loss $L_{\mathcal{D}}$ in terms of the empirical (sample) loss $L_S$
• Here, the gap between $L_{\mathcal{D}}$ and $L_S$ is bounded by a function of a complexity measure
• e.g., VC dimension, norm-based, distance from initialization, optimization-based
• In particular, the flatness-based measures correlate well with empirical results [1]
• Theorem. With high probability over 𝑆, the flatness-based bound says:
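One standard statement of this bound, following Theorem 1 of the SAM paper (a paraphrase rather than the exact constants, with $h$ a strictly increasing function), is:

$$ L_{\mathcal{D}}(\boldsymbol{w}) \;\le\; \max_{\|\boldsymbol{\epsilon}\|_2 \le \rho} L_S(\boldsymbol{w} + \boldsymbol{\epsilon}) \;+\; h\!\left(\frac{\|\boldsymbol{w}\|_2^2}{\rho^2}\right) $$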
Flatness-based generalization bounds
4
The last term $h(\|\boldsymbol{w}\|_2^2/\rho^2)$ is a strictly increasing function of its argument,
and it decreases as the number of samples $n = |S|$ increases.
[1] Fantastic Generalization Measures and Where to Find Them. ICLR 2020.
• Theorem. With high probability over 𝑆, the flatness-based bound says:
• Proof sketch) Given a prior and a posterior over the weights, the PAC-Bayes* bound
(a standard form is given below) upper-bounds the population loss
• Assuming the prior and the posterior to be Gaussian distributions, and applying some auxiliary bounds,
• using the fact that $\|\boldsymbol{\epsilon}\|_2 \le \rho$ with high probability (for the R.H.S.)
and that 𝒘 is a local minimum (for the L.H.S.)
concludes the proof.
Flatness-based generalization bounds
5
* Probably approximately correct (PAC)
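For reference, a standard McAllester-style form of the PAC-Bayes bound (a sketch of the inequality, not necessarily the exact variant used in the paper's proof), with prior $P$, posterior $Q$, and confidence $1-\delta$:

$$ \mathbb{E}_{\boldsymbol{w} \sim Q}\big[L_{\mathcal{D}}(\boldsymbol{w})\big] \;\le\; \mathbb{E}_{\boldsymbol{w} \sim Q}\big[L_S(\boldsymbol{w})\big] \;+\; \sqrt{\frac{\mathrm{KL}(Q\,\|\,P) + \log\frac{n}{\delta}}{2(n-1)}} $$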
• SAM aims to optimize the minimax objective $\min_{\boldsymbol{w}} \max_{\|\boldsymbol{\epsilon}\|_p \le \rho} L_S(\boldsymbol{w} + \boldsymbol{\epsilon})$
• With a linear (first-order) approximation of $L_S$, the optimal $\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})$ is obtained in closed form,
where $1/p + 1/q = 1$ by the dual norm lemma (its scale depends on $\rho$ and $\nabla_{\boldsymbol{w}} L_S(\boldsymbol{w})$)
• Substituting $\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})$ gives a gradient estimator,
or simply $\nabla_{\boldsymbol{w}} L_S(\boldsymbol{w})\big|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})}$ once the 2nd-order term is dropped (the expressions are written out below)
Sharpness-aware minimization (SAM)
6
2nd-order term is not necessary in practice.
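Written out (following the SAM paper; the general-$p$ solution reduces to the default $p = 2$ case):

$$ \min_{\boldsymbol{w}} \; \max_{\|\boldsymbol{\epsilon}\|_p \le \rho} L_S(\boldsymbol{w} + \boldsymbol{\epsilon}), \qquad \hat{\boldsymbol{\epsilon}}(\boldsymbol{w}) = \rho \,\mathrm{sign}\!\big(\nabla_{\boldsymbol{w}} L_S(\boldsymbol{w})\big)\, \frac{\big|\nabla_{\boldsymbol{w}} L_S(\boldsymbol{w})\big|^{q-1}}{\big(\|\nabla_{\boldsymbol{w}} L_S(\boldsymbol{w})\|_q^q\big)^{1/p}} $$

$$ \text{For } p = 2: \quad \hat{\boldsymbol{\epsilon}}(\boldsymbol{w}) = \rho\, \frac{\nabla_{\boldsymbol{w}} L_S(\boldsymbol{w})}{\|\nabla_{\boldsymbol{w}} L_S(\boldsymbol{w})\|_2}, \qquad \nabla_{\boldsymbol{w}} L_S^{\mathrm{SAM}}(\boldsymbol{w}) \approx \nabla_{\boldsymbol{w}} L_S(\boldsymbol{w})\Big|_{\boldsymbol{w} + \hat{\boldsymbol{\epsilon}}(\boldsymbol{w})} $$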
• The gradient estimator of SAM is given by $\nabla_{\boldsymbol{w}} L_S(\boldsymbol{w})\big|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})}$
• Recall that SAM requires a 2-step gradient computation (thus, roughly twice as slow as ERM);
a minimal code sketch follows this slide
• 1st step: compute $\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})$ using $\nabla_{\boldsymbol{w}} L_S(\boldsymbol{w})$
• 2nd step: compute $\nabla_{\boldsymbol{w}} L_S(\boldsymbol{w})\big|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})}$
• In practice, using the SAM gradient for only ∼20% of the updates (ERM otherwise) is sufficient
• Default setup: $p = 2$ (2-norm) and neighborhood size $\rho = 0.05$
Sharpness-aware minimization (SAM)
7
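A minimal PyTorch-style sketch of this 2-step update, assuming $p = 2$ and a generic `model`, `loss_fn`, and `optimizer` (hypothetical names; a sketch under these assumptions, not the authors' released implementation):

```python
import torch

def sam_step(model, loss_fn, x, y, optimizer, rho=0.05):
    """One SAM update with p = 2 (a sketch, not the authors' implementation)."""
    optimizer.zero_grad()

    # 1st step: gradient at the original weights w, used to build eps_hat(w).
    loss_fn(model(x), y).backward()
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    grad_norm = torch.norm(torch.stack([g.norm(p=2) for g in grads]), p=2)

    eps = []
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is None:
                continue
            e = rho * p.grad / (grad_norm + 1e-12)  # eps_hat(w) = rho * grad / ||grad||_2
            p.add_(e)                               # move to w + eps_hat(w)
            eps.append((p, e))

    # 2nd step: gradient at the perturbed weights w + eps_hat(w).
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()

    with torch.no_grad():
        for p, e in eps:
            p.sub_(e)      # restore the original weights w

    optimizer.step()       # apply the SAM gradient to the original weights
    optimizer.zero_grad()
    return loss.item()
```

The extra forward/backward pass is what makes SAM roughly twice as expensive per update as ERM.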
• Loss surface visualization [1]
(two random directions)
• Hessian spectra
Verification of the flatness
8
[1] Visualizing the Loss Landscape of Neural Nets. NeurIPS 2018.
(Figure panels: ERM vs. SAM)
• SAM consistently improves classification tasks, particularly under label noise
• …though the gains are less significant for CNN architectures (see ViT later)
Results (i.e., SAM > ERM)
9
• Entropy-SGD [1] minimizes the local entropy of the loss around the current weights
• It minimizes an average over the neighborhood, rather than the worst case as SAM does
• Stochastic weight averaging (SWA) [2] simply averages the intermediate weights (see the sketch below)
• It does not guarantee flat minima, but empirically works well
• It is similar to an exponential moving average (EMA) of the weights
• Adversarial weight perturbation (AWP) [3] considers a similar idea to SAM, but
under the adversarial robustness setting
• Here, one can compute data and weight perturbations simultaneously
Some related works
10
[1] Entropy-SGD: Biasing Gradient Descent Into Wide Valleys. ICLR 2017.
[2] Averaging Weights Leads to Wider Optima and Better Generalization. UAI 2018.
[3] Adversarial Weight Perturbation Helps Robust Generalization. NeurIPS 2020.
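A minimal sketch of SWA-style averaging over checkpoints collected along the SGD trajectory (hypothetical helper, not the authors' code; in practice, BatchNorm statistics are re-estimated after loading the averaged weights):

```python
import copy
import torch

@torch.no_grad()
def swa_average(checkpoints):
    """Average a list of model state_dicts (SWA sketch); load the result with load_state_dict."""
    avg = copy.deepcopy(checkpoints[0])
    for key in avg:
        if avg[key].is_floating_point():
            avg[key] = torch.stack([ckpt[key].float() for ckpt in checkpoints]).mean(dim=0)
    return avg
```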
• Recall: Vision Transformer (ViT) [2] uses a Transformer-only architecture,
and MLP-Mixer [3] replaces the Transformer with MLP layers
SAM for ViT (and MLP-Mixer)
11
[1] When Vision Transformers Outperform ResNets without Pretraining or Strong Data Augmentations. Under review.
[2] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021.
[3] MLP-Mixer: An all-MLP Architecture for Vision. Under review.
The contents from this point on
are from [1]
MLP-Mixer: patch-wise MLP + channel-wise MLP
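A minimal sketch of one Mixer block, i.e., a token-mixing (patch-wise) MLP followed by a channel-mixing MLP (hypothetical hidden sizes, not the official implementation):

```python
import torch.nn as nn

class MixerBlock(nn.Module):
    """One MLP-Mixer block: token-mixing MLP across patches, then channel-mixing MLP per patch."""
    def __init__(self, num_patches, dim, token_hidden=256, channel_hidden=2048):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.token_mlp = nn.Sequential(    # mixes information across patches
            nn.Linear(num_patches, token_hidden), nn.GELU(), nn.Linear(token_hidden, num_patches))
        self.norm2 = nn.LayerNorm(dim)
        self.channel_mlp = nn.Sequential(  # mixes information across channels
            nn.Linear(dim, channel_hidden), nn.GELU(), nn.Linear(channel_hidden, dim))

    def forward(self, x):                  # x: (batch, num_patches, dim)
        y = self.norm1(x).transpose(1, 2)  # (batch, dim, num_patches)
        x = x + self.token_mlp(y).transpose(1, 2)
        x = x + self.channel_mlp(self.norm2(x))
        return x
```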
• ViT works well… but only with large-scale datasets and strong data augmentations
• Q. Why is it not satisfactory under moderate dataset sizes?
• A. It is an optimization problem!
1. The loss surfaces of ViT and Mixer are sharper than ResNet's
Sharp local minima of ViT (and MLP-Mixer)
12
• ViT works well… but only with large-scale datasets and strong data augmentations
• Q. Why is it not satisfactory under moderate dataset sizes?
• A. It is an optimization problem!
2. The training curves of ViT are often unstable [1]
• In particular, the shallow layers (e.g., the patch embedding) are unstable
Sharp local minima of ViT (and MLP-Mixer)
13
[1] An Empirical Study of Training Self-Supervised Vision Transformers. Under review.
(Figure panels: accuracy dips; gradient spikes)
• SAM significantly improves ViT and MLP-Mixer (in-domain acc. + OOD robustness)
ViT & MLP-Mixer + SAM
14
• SAM reduces the sharp-local-minima issue of the shallow layers
• Note that the Hessian norm accumulates in the shallow layers
• SAM also leads ViT and MLP-Mixer to learn sparse activations
• Fewer than 5% of neurons are activated for ViT-SAM
• The activations are especially sparse in the shallow layers
ViT & MLP-Mixer + SAM
15
(Figure y-axis: % of activated neurons)
• ViT-SAM produces more interpretable attention maps
ViT & MLP-Mixer + SAM
16
• The “flat minima” found by SAM are highly effective, especially for ViT and MLP-Mixer
• Many interesting properties emerge for ViT
→ ViT is now stable and better than ResNet under moderate dataset sizes!
Take-home message
17
Thank you for listening! 😀
