Multilingual models jointly pretrained on multiple languages have achieved remarkable performance on various multilingual downstream tasks. Moreover, models finetuned on a single monolingual downstream task have been shown to generalize to unseen languages. In this paper, we first show that it is crucial to align gradients between those tasks in order to maximize knowledge transfer while minimizing negative transfer. Despite its importance, existing methods for gradient alignment either have a completely different purpose, ignore inter-task alignment, or aim to solve continual learning problems in rather inefficient ways. As a result of misaligned gradients between tasks, the model suffers from severe negative transfer in the form of catastrophic forgetting of the knowledge acquired during pretraining. To overcome these limitations, we propose a simple yet effective method that can efficiently align gradients between tasks. Specifically, we perform each inner optimization by sequentially sampling batches from all the tasks, followed by a Reptile outer update. Thanks to the gradients aligned between tasks by our method, the model becomes less vulnerable to negative transfer and catastrophic forgetting. We extensively validate our method on various multi-task learning and zero-shot cross-lingual transfer tasks, where our method largely outperforms all the relevant baselines we consider.
1. Sequential Reptile: Inter-Task Gradient Alignment for Multilingual Learning
Seanie Lee, Hae Beom Lee, Juho Lee, Sung Ju Hwang
2. Data Scarcity
There is not enough labeled data for non-English languages.
[Figure: examples of low-resource languages such as Finnish, Indonesian, Bengali, Telugu, Yoruba, and Swahili]
3. Pretrained Multilingual Language Model
Language models pretrained on multilingual corpora show impressive performance on low-resource languages.
Multilingual BERT
XLM
Multilingual T5
4. Multi-Task Learning (1)
Assuming there is a common structure across languages, we can leverage multi-task learning to mitigate data scarcity.
[Figure: tasks in six languages (Finnish, Indonesian, Bengali, Telugu, Yoruba, Swahili) trained jointly with a single multilingual model]
5. Multi-Task Learning (2)
Given $T$ tasks, we want to estimate a parameter $\theta$ minimizing the sum of the task losses:

$$\min_{\theta} \; \sum_{t=1}^{T} \mathcal{L}_t(\theta)$$
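As a minimal sketch of this objective (the names `model`, `batches`, and `loss_fn` are illustrative, not from the paper), one training step simply sums the per-task losses:

```python
def mtl_loss(model, batches, loss_fn):
    """Plain multi-task objective: the sum of per-task losses, where
    `batches` maps each task t to one mini-batch (x_t, y_t)."""
    return sum(loss_fn(model(x), y) for (x, y) in batches.values())
```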
6. Catastrophic Forgetting
Finetuning pretrained language models leads to catastrophic forgetting of the knowledge acquired during pretraining [1,2].
[1] Lee et al. Mixout: Effective Regularization to Finetune Large-scale Pretrained Language Models. ICLR 2020.
[2] Chen et al. Recall and Learn: Fine-tuning Deep Pretrained Language Models with Less Forgetting. EMNLP 2020.
[Figure: QA example. Paragraph: "Philadelphia has more murals than any other U.S. city, thanks in part to the 1984 …"; question: "Which city has more murals than any other city?"; answer: "Philadelphia". After finetuning on QA, the model fails the masked-LM probe "[MASK] city has more murals than any other [MASK]?"]
7. Gradient Alignment and Conflict
For MTL, we need to maximize knowledge transfer between languages and minimize negative interference.
We need to align task gradients and avoid gradient conflict, which prevents the model from memorizing task-specific knowledge.

[Figure: task gradients in parameter space $\theta$ under gradient conflict vs. gradient alignment]
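One way to make conflict concrete is to measure the cosine similarity between task gradients; a minimal PyTorch sketch (helper names are assumptions for illustration, with the two losses computed from separate forward passes):

```python
import torch
import torch.nn.functional as F

def flat_grad(model, loss):
    """Flatten the gradient of one task's loss w.r.t. the shared parameters."""
    params = [p for p in model.parameters() if p.requires_grad]
    grads = torch.autograd.grad(loss, params, retain_graph=True)
    return torch.cat([g.reshape(-1) for g in grads])

def grad_alignment(model, loss_i, loss_j):
    """Cosine similarity between two task gradients:
    > 0 means the tasks are aligned, < 0 means they conflict."""
    return F.cosine_similarity(flat_grad(model, loss_i),
                               flat_grad(model, loss_j), dim=0)
```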
8. Related Work on Gradient Alignment
PCGrad [3] and GradVac [4] manually alter gradients to aggressively minimize the MTL objective.
[3] Yu et al. Gradient Surgery for Multi-Task Learning. NeurIPS 2020.
[4] Wang et al. Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models. ICLR 2021.
[Figure: gradient modification in PCGrad and GradVac]
9. Explicit Gradient Alignment
Explicitly maximizing the dot product of task gradients is expensive: it requires computing the Hessian with respect to the model parameters $\theta$.
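To see where the cost comes from, differentiate the alignment term directly; with $g_t = \nabla_\theta \mathcal{L}_t(\theta)$, the chain rule yields Hessian-vector products:

$$\nabla_\theta \big( g_i^{\top} g_j \big) = \nabla_\theta^{2} \mathcal{L}_i(\theta)\, g_j + \nabla_\theta^{2} \mathcal{L}_j(\theta)\, g_i = H_i\, g_j + H_j\, g_i$$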
10. Implicit Gradient Alignment
Reptile [5] shows that SGD implicitly aligns gradients of mini-batches within a task, without any second-order derivatives.
[5] Nichol et al. On First-Order Meta-Learning Algorithms. arXiv 2018.
[Figure: a Reptile inner-loop trajectory and the resulting outer update]
11. Limitation of Reptile
Reptile performs inner optimization independently for each task, so it cannot align gradients across tasks.

[Figure: separate per-task inner trajectories $\theta^{(1)}, \theta^{(2)}$ in vanilla Reptile]
12. Sequential Reptile
We propose Sequential Reptile, where the inner trajectory consists of mini-batches from all tasks.

[Figure: a single inner trajectory interleaving mini-batches from all tasks, followed by the Reptile outer update]
13. Experimental Setup
• Tasks
- Multilingual NLP tasks (QA, NER, NLI)
- Each language serves as a distinct task for MTL.
• Baselines
1) Base MTL
2) PCGrad
3) GradVac
4) RecAdam [6]
5) GradNorm [7]
6) Reptile
[6] Chen et al. Recall and Learn: Fine-tuning Deep Pretrained Language Models with Less Forgetting. EMNLP 2020.
[7] Chen et al. GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks. ICML 2018.
14. Experimental Result - QA
We train multilingual BERT (m-BERT) on the TyDiQA dataset for question answering.
15. Experimental Result - NER
We train multilingual BERT (m-BERT) on the WikiAnn dataset for named entity recognition.
17. Analysis (2)
Sequential Reptile achieves low masked language modeling loss and low $\ell_2$ distance from the pretrained model.
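The $\ell_2$ distance here can be computed directly from the two checkpoints; a small sketch (the model handles are illustrative):

```python
import torch

def l2_from_pretrained(finetuned, pretrained):
    """l2 distance between finetuned and pretrained parameters; a small
    value means finetuning stayed close to the pretrained solution."""
    sq = sum((p - q).pow(2).sum()
             for p, q in zip(finetuned.parameters(), pretrained.parameters()))
    return torch.sqrt(sq)
```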
18. Analysis (3)
To verify that Sequential Reptile retains general knowledge across languages, we perform zero-shot cross-lingual transfer experiments.
Seen Languages: ar, bn, en, fi, id, ko, ru, sw, te
19. Zero-Shot Cross-Lingual Transfer
Train m-BERT on English labeled data and evaluate it on unseen languages.
We partition the English data into four disjoint clusters and treat each cluster as a task.
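One plausible way to build such a partition is to cluster example embeddings with k-means; note that k-means and the embedding source are assumptions here for illustration, as the slide only specifies four disjoint clusters treated as tasks:

```python
from sklearn.cluster import KMeans

def partition_into_tasks(embeddings, num_clusters=4, seed=0):
    """Split English training examples into disjoint pseudo-tasks by
    clustering their embeddings (k-means is an illustrative choice).
    Returns the example indices belonging to each cluster."""
    labels = KMeans(n_clusters=num_clusters, random_state=seed,
                    n_init=10).fit_predict(embeddings)
    return [[i for i, lab in enumerate(labels) if lab == c]
            for c in range(num_clusters)]
```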
20. Experimental Result - QA
We train multilingual BERT (m-BERT) on SQuAD and evaluate it on the MLQA dataset for question answering.
21. Experimental Result - NLI
We train multilingual BERT (m-BERT) on MNLI and evaluate it on the XNLI dataset for natural language inference.
22. Experimental Result - Image Classification
We finetune a ResNet-18 pretrained on ImageNet on 8 different image classification datasets.
24. Conclusion
• We observe that gradient alignment is important for knowledge transfer and for preventing catastrophic forgetting.
• We propose an efficient algorithm that aligns task gradients without computing second-order derivatives.
• We verify the efficacy of Sequential Reptile on various datasets, including NLP and vision tasks.