Memory Efficient Adaptive Optimization

Adaptive gradient-based optimizers such as Adagrad and Adam are crucial for achieving state-of-the-art performance in machine translation and language modeling. However, these methods maintain second-order statistics for each parameter, thus introducing significant memory overheads that restrict the size of the model being used as well as the number of examples in a mini-batch. We describe an effective and flexible adaptive optimization method with greatly reduced memory overhead. Our method retains the benefits of per-parameter adaptivity while allowing significantly larger models and batch sizes. We give convergence guarantees for our method, and demonstrate its effectiveness in training very large translation and language models with up to 2-fold speedups compared to the state-of-the-art.
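To make the memory overhead concrete, here is a rough counting sketch. It assumes the row-and-column cover described in the slides below, where one accumulator is stored per row and per column of a weight matrix. The function name `accumulator_floats` and the example shape are hypothetical, chosen only for illustration.

```python
def accumulator_floats(shape, method):
    """Count the extra floats an optimizer keeps for one m x n weight matrix."""
    m, n = shape
    if method == "adagrad":
        # Diagonal Adagrad: one accumulator per parameter.
        return m * n
    if method == "sm3":
        # Row/column cover: one accumulator per row plus one per column.
        return m + n
    raise ValueError(f"unknown method: {method}")


# Hypothetical embedding table: 32,000 vocabulary entries x 1,024 dimensions.
shape = (32_000, 1_024)
for method in ("adagrad", "sm3"):
    print(method, accumulator_floats(shape, method))
```

For this shape the per-parameter accumulator needs as many floats as the matrix itself, while the row/column cover needs only a small fraction of that.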

https://arxiv.org/abs/1901.11150

Published in: Engineering

  1. SM3: Memory Efficient Adaptive Optimization. Rohan Anil, Vineet Gupta, Tomer Koren, Yoram Singer. Principles Of Effective Machine-Learning (POEM), Google Brain.
  2. Stochastic Optimization
     • Given the universe of examples $\mathcal{D} = \{(x, y)\}$ (with prob. distribution).
     • Learner starts with a function $F(x; w)$ (model or architecture).
     • Learner wants to learn $w = \arg\min_{w \in \Omega} \mathbb{E}_{(x,y) \in \mathcal{D}}\, \ell(F(x; w), y)$.
     • Examples of $\ell$: squared loss, $\ell(\hat{y}, y) = |y - \hat{y}|^2$; cross-entropy, $\ell(\hat{y}, y) = \log(1 + \exp(-y\hat{y}))$.
  3. Stochastic Optimization
     Learner starts with a function $F(x; w)$ and wants to learn $w \in \mathbb{R}^d$. At each round $t$:
     • Learner has decided upon $w_t$.
     • Learner receives input $X_t = \{x_i,\ i = 1, \dots, k\}$.
     • Learner makes prediction $\hat{y}_t = \{F(x_i; w_t),\ i = 1, \dots, k\}$.
     • Learner receives true outcome $y_t = \{y_i,\ i = 1, \dots, k\}$.
     • Learner computes loss $\ell(\hat{y}_t, y_t)$ and gradient $g_t = \nabla_w \ell(\hat{y}_t, y_t)$.
     • Learner uses $g_t$ to update $w_t$ to get $w_{t+1}$.
     Stop when: gradients vanish, or (more usually) run out of time/patience.
  4. Stochastic Optimization
     Regret: cumulative loss compared to any fixed vector in $\mathbb{R}^d$:
     $R_T = \sum_{t=1}^{T} \ell(\hat{y}_t, y_t) - \min_{w \in \Omega} \sum_{t=1}^{T} \ell(F(X_t; w), y_t)$
     Convergence: $R_T / T \to 0$ as $T \to \infty$.
  5. Stochastic Gradient Descent (1951)
     $w_{t+1} = \arg\min_{w \in \mathbb{R}^d} \left( \frac{1}{2\eta_t} \|w - w_t\|_2^2 + \ell(\hat{y}_t, y_t) \right)$
     (Update) $w_{t+1} = w_t - \eta_t g_t$
     $R_T = O(\sqrt{T})$ (Zinkevich, 2003)
  6. Convergence Problems in SGD
     Euclidean norm: $\|w - w_t\|_2^2$
     Preconditioned norm: $\|w - w_t\|_H^2 = (w - w_t)^\top H (w - w_t)$
     Faster convergence, better condition number.
  7. Adaptive Preconditioning - Adagrad (2010)
     • Preconditioning changes with time: it adapts to the geometry of the data.
     • Learning rate is tuned for each parameter.
     Adagrad update:
     $w_{t+1} = \arg\min_{w \in \mathbb{R}^d} \left( \frac{1}{2\eta} \|w - w_t\|_{H_t}^2 + \ell(\hat{y}_t, y_t) \right)$
     $H_t = \left( \sum_{s \le t} g_s g_s^\top \right)^{1/2}$
     $w_{t+1} = w_t - \eta H_t^{-1} g_t$
     Big caveats!
     • Needs $O(n^2)$ storage ($n$ is the number of parameters).
     • Needs $O(n^3)$ time per step for SVD, root finding, inverses etc.
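  8. Adaptive Preconditioning - Adagrad
     One solution: diagonal approximation.
     $w_{t+1} = w_t - \eta H_t^{-1} g_t$, where $H_t = \mathrm{diag}\left( \sum_{s \le t} g_s g_s^\top \right)^{1/2}$,
     i.e. $(H_t)_{ii} = \left( \sum_{s \le t} g_{s,i}^2 \right)^{1/2}$
     (component-wise): $w_{t+1} = w_t - \eta \, \frac{g_t}{\sqrt{\sum_{s \le t} g_s^2}}$
     $O(n)$ extra space, $O(n)$ time per step.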
  9. Still too much memory
     [Figure: GPU RAM and TPU RAM compared with model parameter counts, 2011 to 2019; axis ticks 1, 2, 4, 8, 16, 32. Devices: Fermi, C1060, K20X, K40, P100, V100, P40, TPU v2, TPU v3. Models: RNN, Transformer, Bert Large, GPT2, MegatronLM.]
  10. Preconditioned methods need 2x memory
     • $w_t$: parameters
     • $H_t$: diagonal preconditioner (per-parameter learning rates)
  11. But learning rates are often correlated …
     [Figure: Adagrad accumulators for a Transformer-Big model for WMT’14 en-fr (color intensities in log scale). Panels: output softmax layer, attention layer.]
  12. … rows and columns have similar magnitudes.
     [Figure: Adagrad accumulators for an AmoebaNet-D model for ImageNet (color intensities in log scale). Panels: filter of 7x1x256x256 convolution, output softmax layer.]
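  13. SM3
     Adagrad accumulator for a matrix of parameters: $H_{t,ij} = \sum_{s \le t} g_{s,ij}^2$
     SM3 keeps one accumulator $R_{t,i}$ per row and one accumulator $C_{t,j}$ per column:
     $\hat{H}_{t+1,ij} = \min(R_{t,i}, C_{t,j}) + g_{t+1,ij}^2$
     $R_{t+1,i} = \max_j \hat{H}_{t+1,ij}$
     $C_{t+1,j} = \max_i \hat{H}_{t+1,ij}$
     More generally, $\{S_r\}_{r=1}^{k}$ is a cover of $[d]$ if $S_r \subseteq [d]$ and $\cup_{r=1}^{k} S_r = [d]$.
     In this example: the covers are the rows and columns. Store one float per $S_r$.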
  14. SM3: Save Memory by Sharing Moments of Similar Magnitude
     SM3
     1: parameters: learning rate $\eta$, cover $\{S_r\}_{r=1}^{k}$
     2: initialize $w_1 = 0$; $\forall r \in [k]: \mu_0(r) = 0$
     3: for $t = 1, \dots, T$ do
     4:    receive gradient $g_t = \nabla \ell_t(w_t)$
     5:    initialize $\mu_t(r) = 0$ for all $r \in [k]$
     6:    for $i = 1, \dots, d$ do
     7:       set $\nu_t(i) \leftarrow \min_{r : S_r \ni i} \mu_{t-1}(r) + g_t(i)^2$
     8:       update $w_{t+1}(i) \leftarrow w_t(i) - \eta \, g_t(i) / \sqrt{\nu_t(i)}$ (Convention: 0/0 = 0)
     9:       for all $r : S_r \ni i$ do
     10:         set $\mu_t(r) \leftarrow \max\{\mu_t(r), \nu_t(i)\}$
  15. Convergence of SM3
     Lemma: For any $i \in [d]$:
     • $\nu_1(i), \dots, \nu_t(i)$ is monotonically increasing
     • $\sum_{s \le t} g_s^2(i) \le \nu_t(i)$
     Theorem: Assuming the loss functions $\ell_1, \dots, \ell_t$ are convex, for any $w^* \in \mathbb{R}^d$:
     $\sum_{s=1}^{t} \left( \ell_s(w_s) - \ell_s(w^*) \right) \le 2D \sum_{i=1}^{d} \sqrt{\nu_t(i)}$
  16. Transformer-Big on WMT ’14 en→fr
     [Figure: -log(perplexity) versus steps, two panels (batch size 384 and batch size 768). Curves: Adagrad, Adam, Adafactor, SM3.]
  17. BERT-Large
     [Figure: accuracy (55% to 75%) versus steps. Curves: Adam (batch size: 1024), Adagrad (batch size: 1024), Adafactor (batch size: 1024), SM3 (batch size: 1024), SM3 (batch size: 2048).]
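The SM3 listing on slide 14 can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the function name `sm3_step`, the list-based data layout, and the toy gradients are all hypothetical, and the square root in the update follows the Adagrad-style step on slide 8. The loop at the end checks the two lemma properties from slide 15 on the toy data.

```python
import math


def sm3_step(w, g, mu_prev, cover, lr):
    """One SM3 step (illustrative sketch of the slide-14 listing).

    w, g    : lists of d floats (parameters and gradient)
    mu_prev : list of k floats, one accumulator per cover set
    cover   : list of k index lists whose union is range(d)
    Returns (new_w, mu, nu).
    """
    d = len(w)
    # sets_of[i] lists the cover sets r that contain index i.
    sets_of = [[] for _ in range(d)]
    for r, indices in enumerate(cover):
        for i in indices:
            sets_of[i].append(r)

    mu = [0.0] * len(cover)
    nu = [0.0] * d
    new_w = list(w)
    for i in range(d):
        nu[i] = min(mu_prev[r] for r in sets_of[i]) + g[i] ** 2
        if nu[i] > 0.0:  # convention 0/0 = 0: leave w[i] unchanged
            new_w[i] = w[i] - lr * g[i] / math.sqrt(nu[i])
        for r in sets_of[i]:
            mu[r] = max(mu[r], nu[i])
    return new_w, mu, nu


# Toy 2x2 matrix stored row-major; the cover is its two rows and two columns.
cover = [[0, 1], [2, 3], [0, 2], [1, 3]]
w, mu = [0.0] * 4, [0.0] * 4
sum_sq, prev_nu = [0.0] * 4, [0.0] * 4
grads = [[1.0, 0.0, 0.0, 2.0], [0.5, 1.0, 0.0, 0.0], [0.0, 0.0, 3.0, 1.0]]
for g in grads:
    w, mu, nu = sm3_step(w, g, mu, cover, lr=0.1)
    sum_sq = [a + gi ** 2 for a, gi in zip(sum_sq, g)]
    # Lemma: nu upper-bounds the Adagrad accumulator and never decreases.
    assert all(a <= v for a, v in zip(sum_sq, nu))
    assert all(p <= v for p, v in zip(prev_nu, nu))
    prev_nu = nu
```

Only the four values in `mu` persist between steps; `nu` is recomputed each step from `mu` and the current gradient, which is where the memory saving comes from.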

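For comparison, the SGD update from slide 5 and the diagonal Adagrad update from slide 8 can be sketched side by side. The helper names `sgd_step` and `adagrad_step` and the badly conditioned quadratic used as a toy objective are assumptions made for illustration only.

```python
import math


def sgd_step(w, g, lr):
    """Plain SGD: the same learning rate for every coordinate."""
    return [wi - lr * gi for wi, gi in zip(w, g)]


def adagrad_step(w, g, acc, lr):
    """Diagonal Adagrad: per-coordinate rate lr / sqrt(sum of squared gradients)."""
    acc = [a + gi * gi for a, gi in zip(acc, g)]
    new_w = [
        wi - lr * gi / math.sqrt(a) if a > 0.0 else wi  # convention 0/0 = 0
        for wi, gi, a in zip(w, g, acc)
    ]
    return new_w, acc


def grad(w):
    """Gradient of the toy objective f(w) = 0.5 * (100 * w[0]**2 + w[1]**2)."""
    return [100.0 * w[0], w[1]]


w0 = [1.0, 1.0]
w_sgd = sgd_step(w0, grad(w0), lr=0.1)
w_ada, acc = adagrad_step(w0, grad(w0), [0.0, 0.0], lr=0.1)
print("sgd:    ", w_sgd)
print("adagrad:", w_ada)
```

On this toy objective the single SGD step overshoots along the steep first coordinate, while the Adagrad step moves both coordinates by the same amount, at the cost of one extra float per parameter in `acc`.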
