2. Multi-task learning (MTL) is a mechanism for learning multiple tasks in
parallel with a shared model. The goal of MTL is to improve
generalization by taking advantage of the domain-specific knowledge
contained in the training samples of related tasks. An MTL approach
can learn the tasks more quickly and effectively, with less data and
reduced overfitting, than learning them independently.
However, there are some design challenges involved in multi-task
learning. Different tasks may have conflicting needs: trying to
improve the performance of one task may harm the performance
of another task, a phenomenon known as negative transfer.
Combining the per-task losses into a single training objective is another challenge.
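To make the loss-combination point concrete, here is a minimal sketch in PyTorch of the simplest option, a weighted sum of two task losses; the task names and the fixed weights are illustrative, and in practice the weights can also be learned (for example via task-uncertainty weighting).

```python
import torch.nn as nn

# Hypothetical per-task losses for one shared model with two output heads.
classification_loss = nn.CrossEntropyLoss()
regression_loss = nn.MSELoss()

def combined_loss(class_logits, class_targets, reg_preds, reg_targets,
                  w_class=1.0, w_reg=0.5):
    """Weighted sum of per-task losses; w_class and w_reg are hand-tuned
    hyperparameters here, though they could also be learned."""
    return (w_class * classification_loss(class_logits, class_targets)
            + w_reg * regression_loss(reg_preds, reg_targets))
```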
3. In the context of deep learning, multi-task learning is most often done by
hard parameter sharing, an architecture design in which the model weights are
shared between tasks and trained to minimize multiple loss functions.
Soft parameter sharing instead gives each task its own model with separate
weights, but the distance between the model parameters of the different tasks
is added to the joint objective function.
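A minimal sketch of soft parameter sharing in PyTorch follows (layer sizes and the penalty weight are illustrative); hard parameter sharing corresponds to the shared-trunk sketch under the next point.

```python
import torch.nn as nn

# Two task-specific models with the same structure but separate weights.
task_a = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
task_b = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 1))

def soft_sharing_penalty(model_a, model_b):
    """L2 distance between corresponding parameters of the two task models,
    added to the joint objective so the weights stay close without being tied."""
    penalty = 0.0
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        if p_a.shape == p_b.shape:   # skip layers that differ (e.g. output heads)
            penalty = penalty + (p_a - p_b).pow(2).sum()
    return penalty

# Joint objective (loss_a, loss_b computed elsewhere), with an illustrative weight:
# loss = loss_a + loss_b + 1e-3 * soft_sharing_penalty(task_a, task_b)
```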
4. Shared trunk:
Most multi-task architectures for computer vision follow a template: a series of
convolutional layers shared between all tasks acts as a base feature extractor,
followed by task-specific output heads that use the extracted features as input.
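A minimal shared-trunk sketch in PyTorch (hard parameter sharing); the layer sizes and the two tasks chosen here are illustrative.

```python
import torch.nn as nn

class SharedTrunkMTL(nn.Module):
    """A convolutional trunk shared by all tasks, followed by
    task-specific output heads that consume the shared features."""
    def __init__(self, num_classes=10, num_seg_classes=5):
        super().__init__()
        self.trunk = nn.Sequential(                  # shared feature extractor
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
        )
        self.classification_head = nn.Sequential(    # task 1: image classification
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, num_classes))
        self.segmentation_head = nn.Conv2d(64, num_seg_classes, 1)  # task 2

    def forward(self, x):
        features = self.trunk(x)
        return self.classification_head(features), self.segmentation_head(features)
```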
5. Multi-gate Mixture of Experts:
The multi-gate mixture of experts (MMoE) explicitly learns to model task relationships
from data. The structure is a set of expert sub-models shared across all tasks, with a
separate gating network per task that learns how to weight the experts for that task.
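A small MMoE sketch in PyTorch; the number of experts, dimensions, and task towers are illustrative.

```python
import torch
import torch.nn as nn

class MMoE(nn.Module):
    """Experts shared by all tasks, one softmax gate per task that weights
    the experts, and a small task-specific tower on top of the mixture."""
    def __init__(self, in_dim=32, expert_dim=16, num_experts=4, num_tasks=2):
        super().__init__()
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(in_dim, expert_dim), nn.ReLU())
             for _ in range(num_experts)])
        self.gates = nn.ModuleList(
            [nn.Linear(in_dim, num_experts) for _ in range(num_tasks)])
        self.towers = nn.ModuleList(
            [nn.Linear(expert_dim, 1) for _ in range(num_tasks)])

    def forward(self, x):
        expert_out = torch.stack([e(x) for e in self.experts], dim=1)  # (B, E, D)
        outputs = []
        for gate, tower in zip(self.gates, self.towers):
            weights = torch.softmax(gate(x), dim=-1).unsqueeze(-1)     # (B, E, 1)
            mixed = (weights * expert_out).sum(dim=1)                  # (B, D)
            outputs.append(tower(mixed))
        return outputs
```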
6. Cross-Talk:
A cross-talk architecture has a separate network for each task, with information
flowing between parallel layers of the task networks.
In a cross-stitch network, each task is handled by its own network, and the input
to each layer is a linear combination of the previous layer's outputs from every
task network; the combination weights are learned and task-specific.
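A sketch of a cross-stitch unit for two tasks in PyTorch; using a single scalar weight per task pair is a simplification of the original formulation, and the near-identity initialization is an illustrative choice.

```python
import torch
import torch.nn as nn

class CrossStitchUnit(nn.Module):
    """The input to the next layer of each task network is a learned
    linear combination of both tasks' activations at the current layer."""
    def __init__(self):
        super().__init__()
        # 2x2 matrix of learnable mixing weights, initialized near identity
        # so each task starts out mostly using its own features.
        self.alpha = nn.Parameter(torch.tensor([[0.9, 0.1],
                                                [0.1, 0.9]]))

    def forward(self, x_a, x_b):
        out_a = self.alpha[0, 0] * x_a + self.alpha[0, 1] * x_b
        out_b = self.alpha[1, 0] * x_a + self.alpha[1, 1] * x_b
        return out_a, out_b
```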
7. Prediction Distillation:
Prediction distillation techniques are based on the same idea as multi-task learning:
features learned for one task may be helpful in performing another related task.
PAD-Net is a prediction distillation architecture in which preliminary predictions
are made for each task and then combined by a multi-modal distillation network to
produce refined final outputs.
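A PAD-Net-style sketch in PyTorch showing the overall flow rather than the published implementation; the backbone, the two auxiliary tasks, and all layer sizes are illustrative.

```python
import torch
import torch.nn as nn

class PredictionDistillation(nn.Module):
    """A shared backbone produces preliminary per-task predictions, which are
    fused by a distillation module to refine the final prediction."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU())
        # Preliminary (intermediate) predictions for two auxiliary tasks.
        self.prelim_depth = nn.Conv2d(64, 1, 1)
        self.prelim_seg = nn.Conv2d(64, 5, 1)
        # Distillation module fuses the multi-modal preliminary predictions.
        self.fuse = nn.Sequential(nn.Conv2d(1 + 5, 64, 3, padding=1), nn.ReLU())
        self.final_depth = nn.Conv2d(64, 1, 1)   # refined final prediction

    def forward(self, x):
        feats = self.backbone(x)
        depth0, seg0 = self.prelim_depth(feats), self.prelim_seg(feats)
        fused = self.fuse(torch.cat([depth0, seg0], dim=1))
        return self.final_depth(fused), depth0, seg0  # supervise all outputs
```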
8. NLP Architecture, BERT for MTL:
The MT-BERT model has shared layers based on BERT. In the shared layers, the BERT
model converts each input sequence into a sequence of embedding vectors, gathering
contextual information through its attention mechanism and encoding it into a vector
for each token. For each task, a fully connected task-specific layer is placed on top
of the BERT layers.
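A minimal sketch of this pattern, assuming the Hugging Face transformers library; the two task names, their label counts, and the use of the [CLS] vector are illustrative choices.

```python
import torch.nn as nn
from transformers import BertModel

class MTBert(nn.Module):
    """Shared BERT encoder with one fully connected task-specific layer per task."""
    def __init__(self, num_sentiment_classes=2, num_topic_classes=4):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")  # shared layers
        hidden = self.bert.config.hidden_size
        self.sentiment_head = nn.Linear(hidden, num_sentiment_classes)
        self.topic_head = nn.Linear(hidden, num_topic_classes)

    def forward(self, input_ids, attention_mask, task):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_vec = out.last_hidden_state[:, 0]     # [CLS] token representation
        head = self.sentiment_head if task == "sentiment" else self.topic_head
        return head(cls_vec)
```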
9. Multi-Modal Architecture:
OmniNet is a unified multi-modal architecture that enables learning tasks with
multiple input domains and supports generic multi-tasking for any set of tasks. OmniNet
consists of multiple peripheral networks connected to a common Central Neural
Processor (CNP). Each peripheral network converts domain-specific input into a
feature representation. The output of the CNP is then passed to task-specific output
heads.
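An OmniNet-inspired sketch in PyTorch, not the published implementation: peripheral networks map each modality to a common feature space, a shared central processor operates on the concatenated sequence, and task-specific heads consume its output. All sizes and the two example tasks are illustrative.

```python
import torch
import torch.nn as nn

class MultiModalMTL(nn.Module):
    """Peripherals per input domain, a shared central processor, task-specific heads."""
    def __init__(self, d_model=128, vocab_size=1000):
        super().__init__()
        # Peripherals: one per input domain.
        self.image_peripheral = nn.Sequential(
            nn.Conv2d(3, d_model, 4, stride=4), nn.ReLU(), nn.Flatten(2))  # (B, d, N)
        self.text_peripheral = nn.Embedding(vocab_size, d_model)
        # Central processor shared across tasks and modalities.
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
        self.cnp = nn.TransformerEncoder(layer, num_layers=2)
        # Task-specific output heads.
        self.caption_head = nn.Linear(d_model, vocab_size)
        self.vqa_head = nn.Linear(d_model, 10)

    def forward(self, image, tokens):
        img_feats = self.image_peripheral(image).transpose(1, 2)  # (B, N, d)
        txt_feats = self.text_peripheral(tokens)                  # (B, T, d)
        fused = self.cnp(torch.cat([img_feats, txt_feats], dim=1))
        pooled = fused.mean(dim=1)
        return self.caption_head(fused), self.vqa_head(pooled)
```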
10. Learned Architectures
Learning the architecture itself, together with the weights of the resulting model,
is another approach to multi-task architecture design. A higher degree of sharing
occurs between similar tasks than between unrelated tasks, which can overcome the
issue of negative transfer between tasks.
In a learned branching architecture, every task shares all layers of the network at
the beginning of training. As training progresses, less related tasks branch apart
into clusters, so that only highly related tasks continue to share many parameters.
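An illustrative sketch of the end state such branching can produce: all tasks share the early layers, two related tasks (here, hypothetically, depth and surface normals) keep sharing a branch, and a less related task has its own branch. How and when the branch points are chosen during training is not shown here.

```python
import torch.nn as nn

class BranchedMTL(nn.Module):
    """Resulting structure of a learned branching architecture (illustrative)."""
    def __init__(self):
        super().__init__()
        self.shared_stem = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1), nn.ReLU())
        # Cluster of highly related tasks keeps sharing one branch.
        self.geometry_branch = nn.Sequential(nn.Conv2d(32, 64, 3, padding=1), nn.ReLU())
        self.depth_head = nn.Conv2d(64, 1, 1)
        self.normals_head = nn.Conv2d(64, 3, 1)
        # A less related task branched off earlier into its own branch.
        self.semantic_branch = nn.Sequential(nn.Conv2d(32, 64, 3, padding=1), nn.ReLU())
        self.seg_head = nn.Conv2d(64, 5, 1)

    def forward(self, x):
        stem = self.shared_stem(x)
        geo = self.geometry_branch(stem)
        sem = self.semantic_branch(stem)
        return self.depth_head(geo), self.normals_head(geo), self.seg_head(sem)
```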
11. The key techniques for multi-task neural networks are shared feature
extractors followed by task-specific layers, learning what to share,
fine-grained parameter sharing, varied parameter sharing, and
sharing and recombination.