FLAMINGO
Understanding
A Visual Language Model for Few-Shot Learning
A DeepMind paper
paper: Flamingo: a Visual Language Model for Few-Shot Learning
OUTLINE
Visual Language Models (VLM) capable of adaptation to novel tasks with few-shot learning
(simply by prompting the model with task-specific examples)
Key architectural innovations:
• Bridge powerful pretrained language-only and vision-only models → Gated X-Attention
• Handle sequences of arbitrarily interleaved visual and textual data
• Seamlessly ingest images or videos as inputs.
Source (paper): Flamingo: a Visual Language Model for Few-Shot Learning
FLAMINGO ARCHITECTURE
[Figure: Flamingo architecture overview, annotated with the labels below]
• Vision Encoder: pre-trained NFNet (Normalizer-Free ResNet), F6 version; pretrained following CLIP with NFNet and Chinchilla
• Perceiver Resampler: responsible for providing a fixed set of visual features for an image or video
• LM layers: pre-trained DeepMind Chinchilla LM layers
• Gated X-Attention: responsible for learning cross-attention scores among visual and textual features
• Input format
Source (paper): Flamingo: a Visual Language Model for Few-Shot Learning
OPENFLAMINGO
The Flamingo code and model are currently not open source.
However, thanks to the open-source AI research community, there is an open-source implementation of it: OpenFlamingo.
We will explore OpenFlamingo to understand the Visual Language Model world!
Paper - OpenFlamingo: An Open-Source Framework for Training Large Autoregressive Vision-Language Models
Code: mlfoundations/open_flamingo: An open-source framework for training large multimodal models. (github.com)
Dataset:
• LAION-2B - an open-source web-scraped dataset consisting of 2B image-text pairs
• MMC4 - Multimodal C4, an open-source dataset of 101M interleaved samples of images and sentences in documents, soft-aligned using CLIP.
• Synthetic: ChatGPT - 417K image-text sequences generated by prompting ChatGPT to generate a sequence of interleaved text and image alt-texts (in place of images).
ARCHITECTURE - VISION ENCODER
Datasets used for vision encoder pre-training:
• ALIGN dataset - Large-scale dataset with ~1.8 billion image-text pairs built from images and their alt-text, which is inherently noisy.
• LTIP dataset - Long Text & Image Pairs dataset. 312 million image-text pairs with longer descriptions than the other datasets used for training.
The vision encoder is contrastively pre-trained from scratch, following the CLIP paper.
CLIP: Contrastive Language-Image Pre-training
Text encoder: GPT-2 (in CLIP), Chinchilla (Flamingo)
Image encoder: ResNet (in CLIP), NFNet F6 (Flamingo)
Core idea of the pre-training (CLIP) is,
1. given a batch of image-text pairs,
2. compare all text and image embedding combinations
using cosine similarity,
3. train to maximize the score between matching text-image
pairs and minimize the score between incorrect ones.
(also called, Contrastive learning)
CLIP: Learning Transferable Visual Models From Natural Language Supervision (arxiv.org)
NFNet: High-Performance Large-Scale Image Recognition Without Normalization (arxiv.org)
ARCHITECTURE - VISION ENCODER
OpenFlamingo's vision encoder is based on the OpenCLIP paper, code repository and models.
OpenCLIP - Reproducible scaling laws for contrastive language-image learning
Like CLIP, it is based on a twin-tower model:
• Text Tower - Responsible for encoding the text input. Based on Transformer language models. A pooling layer is applied to get the final features for comparison.
  • Several language models are supported, e.g. RoBERTa and XLM-RoBERTa (encoder-only) and mT5 (encoder-decoder).
• Vision Tower - Responsible for encoding the image input.
  • Typically one of 2 vision model architectures:
    • Modified ResNet: global average pooling is replaced by Transformer-style multi-head QKV attention where the query is conditioned on the global average-pooled representation of the image
    • ViT
The architecture together with the objective function tries to increase the scores on the diagonal of the text-image similarity matrix (the matching pairs) and reduce the remaining scores.
Open FLAMINGO Implementation
Let’s explore OpenCLIP blocks in a little more detail!
ARCHITECTURE - VISION ENCODER
Open CLIP - Vision tower - Vision Transformer (ViT) (1/2)
ViT Paper: AN IMAGE IS WORTH 16X16 WORDS:TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE
Open FLAMINGO Implementation
Input: image (H, W, C=3), patch size (ph, pw)
Grid height (GH) = H/ph
Grid width (GW) = W/pw
# of patches = (#PH, #PW) = (GH, GW), written P per side when the grid is square
1. Convolution operation (kernel = patch size, stride = patch size, out channels = D)
   X = [B, D, P, P]
2. Tensor transforms
   X = [B, D, P, P] → [B, D, PxP] → [B, PxP, D]
3. Concatenate: CONCAT(X, CLS), where CLS = [D] is the class embedding (a learnable param)
   X = [B, (PxP)+1, D]
4. Add: X + POS, where pos = [(PxP)+1, D] is the positional embedding (2D learnable or 2D sinusoidal)
   X = [B, (PxP)+1, D]
Typically, H = W and ph = pw, so #PH = #PW = P.
For example, take an image of 224 x 224 with 3 channels (e.g., RGB): H = 224, W = 224, C = 3.
With a patch size of 14 x 14 (pw = ph = 14), each patch is 14 x 14 (x 3) pixels.
Then #PH = #PW = 224 / 14 = 16, so the number of patches = 16 x 16 = 256.
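The shape bookkeeping above can be checked with a few lines of Python. This is only a sketch: the helper name `vit_sequence_shape` and the embedding width D = 1024 are illustrative assumptions, not OpenCLIP identifiers or settings.

```python
def vit_sequence_shape(height, width, patch_h, patch_w, embed_dim, batch=1):
    """Return the patch grid and the token tensor shape after the class token is added."""
    if height % patch_h or width % patch_w:
        raise ValueError("image size must be divisible by the patch size")
    grid_h = height // patch_h      # GH = H / ph
    grid_w = width // patch_w       # GW = W / pw
    num_patches = grid_h * grid_w   # P x P when the grid is square
    seq_len = num_patches + 1       # +1 for the class embedding
    return (grid_h, grid_w), (batch, seq_len, embed_dim)


# D = 1024 and B = 2 are assumed values for illustration.
grid, shape = vit_sequence_shape(224, 224, 14, 14, embed_dim=1024, batch=2)
print(grid)   # (16, 16)
print(shape)  # (2, 257, 1024)
```

The sequence length of 257 is the 256 patches plus the single class embedding, matching X = [B, (PxP)+1, D].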
ARCHITECTURE - VISION ENCODER
Open CLIP - Vision tower - Vision Transformer (ViT) (2/2)
AN IMAGE IS WORTH 16X16 WORDS:TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE
Open FLAMINGO Implementation
Input: X = [B, (PxP)+1, D]
1. Patch Dropout (prob, exclude class embedding), https://arxiv.org/abs/2212.00794
   During training, randomly keeps a subset of the patch embeddings based on the probability and drops the rest, so the sequence gets shorter; the class embedding can be excluded from the drops.
   With prob = 0: X = [B, (PxP)+1, D]
2. Layer Norm
   X = [B, (PxP)+1, D]
3. Transformer(# layers, # heads, GeLU, …), self-attention with Q = K = V = X
   tokens = [B, (PxP)+1, D]
4. Attention Pooling: Multi Head Attention with Q = learnable param = [# queries, out dim], K = X, V = X
   pool = [B, # queries, out dim]
   Analogous to Max Pooling/Avg. Pooling in the convolutional models.
   # queries signifies the type of pooling we are interested in.
   If # queries < (PxP)+1, we are doing some sort of aggregate pooling using the attention mechanism.
   If # queries = 1, we are doing a sort of global pooling into a single vector. This setup is typically used in contrastive learning tasks.
   Notice that Q has out dim features whereas K and V have D features.
5. Linear output projection (output dim)
   pool = [B, # queries, output dim]
6. return (pool, tokens)
For CLIP, we need a global representation of the image; hence, a pooling layer is added.
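To make the role of # queries concrete, here is a minimal single-head attention-pooling sketch in plain Python. It is a simplification under stated assumptions: there are no learned projections, so the queries and tokens share one dimension (unlike the out dim versus D distinction noted above), and the function names are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(queries, tokens):
    """Pool tokens [n, d] into [len(queries), d] with dot-product attention."""
    dim = len(tokens[0])
    pooled = []
    for q in queries:
        scores = [sum(qi * ti for qi, ti in zip(q, tok)) / math.sqrt(dim) for tok in tokens]
        weights = softmax(scores)
        pooled.append([sum(w * tok[d] for w, tok in zip(weights, tokens)) for d in range(dim)])
    return pooled


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5], [2.0, 0.0]]  # 5 tokens, D = 2
one_query = [[0.0, 0.0]]                # a zero query gives uniform weights, i.e. mean pooling
two_queries = [[1.0, 0.0], [0.0, 1.0]]  # two queries give two pooled vectors

global_pool = attention_pool(one_query, tokens)
aggregate_pool = attention_pool(two_queries, tokens)
print(global_pool)          # approximately [[0.9, 0.5]]
print(len(aggregate_pool))  # 2
```

The number of output vectors depends only on the number of queries, not on the number of input tokens, which is what makes this a pooling operation.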
ARCHITECTURE - VISION ENCODER
Open CLIP - Text tower - Transformer-based language models, e.g., RoBERTa or T5
1. Tokenizer
   input ids = Tokenizer(Text) → [B, N] (embedded inside the model to [B, N, E])
2. Hugging Face Transformers
   transformer = AutoModel.from_pretrained()
   out = transformer.forward(input ids)
   out.last_hidden_state = [B, N, hidden dim]
3. Pooling (based on configuration) while considering masking: mean pooling, max pooling or CLS pooling
   input shape: [B, N, hidden dim]; output shape: [B, hidden dim]
   • Mean pooling - returns the mean values along the sequence length dimension (N)
   • Max pooling - returns the max values along the sequence length dimension (N)
   • CLS pooling - returns the [CLS] token representation from the hidden state, which is typically the 0th index along the sequence length dimension (N)
4. Linear projection / MLP
   [B, hidden dim] → [B, output dim]
OpenCLIP implementation
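The three pooling modes can be sketched for a single sequence as follows. This is an illustrative plain-Python version, not the OpenCLIP code; the function name `pool` and the toy values are assumptions, and CLS pooling assumes the [CLS] token is at index 0.

```python
def pool(hidden, mask, mode):
    """Pool one sequence [N, hidden_dim] to [hidden_dim], ignoring padded positions."""
    kept = [row for row, keep in zip(hidden, mask) if keep]
    if mode == "mean":
        return [sum(col) / len(kept) for col in zip(*kept)]
    if mode == "max":
        return [max(col) for col in zip(*kept)]
    if mode == "cls":
        return list(hidden[0])  # assumes the [CLS] token sits at index 0
    raise ValueError(f"unknown pooling mode: {mode}")


hidden = [[1.0, 4.0], [3.0, 2.0], [9.0, 9.0]]  # N = 3 tokens, hidden dim = 2
mask = [1, 1, 0]                               # the last position is padding

print(pool(hidden, mask, "mean"))  # [2.0, 3.0]
print(pool(hidden, mask, "max"))   # [3.0, 4.0]
print(pool(hidden, mask, "cls"))   # [1.0, 4.0]
```

Without the mask, the padded row [9.0, 9.0] would dominate both the mean and the max, which is why pooling has to take masking into account.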
ARCHITECTURE - VISION ENCODER
Idea of the CLIP paper is,
1. given a batch of image-text pairs,
2. compare all text and image embedding
combinations using cosine similarity,
3. train to maximize the score between matching
text-image pairs and minimize the score between
incorrect ones. (core idea of Contrastive learning)
Objective function has 2 contrastive losses: one from text to image and the other from image to text.

$$L_{contrastive:txt2img} = -\frac{1}{N} \sum_{i}^{N} \log\left(\frac{\exp(L_i^T V_i \beta)}{\sum_{j}^{N} \exp(L_i^T V_j \beta)}\right)$$

$$L_{contrastive:img2txt} = -\frac{1}{N} \sum_{i}^{N} \log\left(\frac{\exp(V_i^T L_i \beta)}{\sum_{j}^{N} \exp(V_i^T L_j \beta)}\right)$$

$L_i$ = normalized language embedding of the i-th element in the batch
$V_i$ = normalized vision embedding of the i-th element in the batch
$\beta$ = trainable inverse temperature parameter
$N$ = # of elements in the batch
[Figure: pseudo code for the core of an implementation of CLIP]
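The two loss terms can be written out directly from the definitions above. The following is a small plain-Python sketch rather than the OpenCLIP implementation; the function names, the toy embeddings and the fixed value of beta are assumptions for illustration (in training, beta is learned).

```python
import math


def normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def mean_diagonal_nll(logits):
    """Mean over rows i of -log softmax(logits[i])[i]."""
    total = 0.0
    for i, row in enumerate(logits):
        top = max(row)
        log_norm = top + math.log(sum(math.exp(v - top) for v in row))
        total += log_norm - row[i]
    return total / len(logits)


def clip_losses(text_emb, image_emb, beta):
    """Return (txt2img, img2txt) contrastive losses for one batch."""
    text = [normalize(t) for t in text_emb]
    image = [normalize(v) for v in image_emb]
    # logits[i][j] = beta * L_i . V_j (cosine similarity times inverse temperature)
    logits = [[beta * sum(a * b for a, b in zip(t, v)) for v in image] for t in text]
    txt2img = mean_diagonal_nll(logits)
    img2txt = mean_diagonal_nll([list(col) for col in zip(*logits)])  # transpose
    return txt2img, img2txt


text_emb = [[1.0, 0.0], [0.0, 1.0]]
image_emb = [[0.9, 0.1], [0.1, 0.9]]   # image i matches text i
matched = clip_losses(text_emb, image_emb, beta=10.0)
shuffled = clip_losses(text_emb, image_emb[::-1], beta=10.0)
print(matched, shuffled)
```

When each text is paired with its own image the diagonal scores are the largest and both losses are close to zero; reversing the images moves the large scores off the diagonal and the losses grow.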
OpenCLIP implementation
Source: CLIP: Learning Transferable Visual Models From Natural Language Supervision (arxiv.org)
Open CLIP – Contrastive Learning
FLAMINGO ARCHITECTURE - PERCEIVER RESAMPLER
• Provides a fixed number of visual tokens for cross-attention from the varying-size feature maps of the vision encoder
• Input: a variable number of image or video features from the vision encoder [image index, features, feature dimensions]
• Output: a fixed number of visual tokens, set by the # of latent queries
• Reduces the computational complexity of the vision-language cross-attention
Input (vision encoder output): tokens = [B, (PxP)+1, D]
Output: x.shape = [B, T, #latents, D]
Open FLAMINGO implementation
Source (paper): Flamingo: a Visual Language Model for Few-Shot Learning
Currently, OpenFlamingo supports only single-frame inputs (1 frame per video) in the Perceiver Resampler.
FLAMINGO ARCHITECTURE - GATED X-ATTENTION
• Interleaving of
  • pre-trained and frozen text-only Transformer-based language model decoder blocks
  with
  • blocks trained from scratch that cross-attend to the visual output from the Perceiver Resampler.
• A tanh-gating mechanism is used to ensure that, at initialization, the conditioned model yields the same results as the original language model.
• For the tanh(𝛼) gating, 𝛼 is a layer-specific learnable scalar parameter initialized to 0, so tanh(𝛼) = 0 at the start of training.
[Figure: plot of tanh]
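The effect of the gate at initialization can be seen in a few lines. This is a sketch with made-up numbers; the function name `gated_residual` and the feature values are assumptions for illustration.

```python
import math


def gated_residual(x, branch_out, alpha):
    """Compute y = x + tanh(alpha) * branch_out, elementwise."""
    gate = math.tanh(alpha)
    return [xi + gate * bi for xi, bi in zip(x, branch_out)]


lm_features = [0.5, -1.0, 2.0]       # features on the frozen language model path
cross_attn_out = [10.0, 10.0, 10.0]  # output of a freshly initialized cross-attention block

at_init = gated_residual(lm_features, cross_attn_out, alpha=0.0)
after_some_training = gated_residual(lm_features, cross_attn_out, alpha=0.1)
print(at_init)              # [0.5, -1.0, 2.0]
print(after_some_training)  # each value shifted by about 10 * tanh(0.1)
```

With alpha = 0 the new block contributes nothing, however large its output, so the model starts out identical to the original language model; the contribution grows smoothly as alpha moves away from 0.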
Open FLAMINGO implementation
Source (paper): Flamingo: a Visual Language Model for Few-Shot Learning
FLAMINGO ARCHITECTURE - GATED X-ATTENTION - MASKING
• Text is processed by replacing vision data (e.g. images) with an <image> tag and the text is divided into chunks using the <EOC>
(end of chunk) tag. Each chunk contains at most one image, which is always at the start of the chunk signifying that the
subsequent text is assumed to relate only to that image. The beginning of a sequence is also marked with the <BOS> (beginning
of sentence) tag.
• The sequence is now tokenized; each token is given an index (𝜙) corresponding to the index of the preceding image. 0 means
no preceding image.
• Now, using 𝜙 a mask is created that ensures the text tokens (𝑌) can only cross-attend to the image tokens (𝑋) that correspond
to their chunk. The mask is illustrated by dark blue entries (unmasked/visible) and light blue entries (masked).
• For cross-attention, 𝑄𝑢𝑒𝑟𝑦 = 𝑌, 𝐾𝑒𝑦 = 𝑉𝑎𝑙𝑢𝑒 = [𝑋]
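The index 𝜙 and the resulting mask can be sketched for one short sequence. The token strings below are a made-up example and the variable names are assumptions; each image is treated as a single column here, whereas in the model each image contributes several visual tokens.

```python
from itertools import accumulate

tokens = ["<BOS>", "<image>", "A", "cat", "<EOC>", "<image>", "A", "dog", "<EOC>"]

# phi[i] = index of the last image at or before token i (0 means no preceding image)
is_image = [tok == "<image>" for tok in tokens]
phi = list(accumulate(int(flag) for flag in is_image))
print(phi)  # [0, 1, 1, 1, 1, 2, 2, 2, 2]

num_images = sum(is_image)
# mask[i][k] is True when text token i may cross-attend to image k + 1
mask = [[p == image_index for image_index in range(1, num_images + 1)] for p in phi]
print(mask[3])  # [True, False]  -> "cat" sees only the first image
print(mask[7])  # [False, True]  -> "dog" sees only the second image
print(mask[0])  # [False, False] -> <BOS> has no preceding image
```

A running count of `<image>` tags is all that is needed to assign every token to its chunk.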
Open FLAMINGO implementation
Source (paper): Flamingo: a Visual Language Model for Few-Shot Learning
FLAMINGO ARCHITECTURE: GATED X-ATTENTION (1/7)
• The concept of a mixin is used to add gated cross-attention functionality to a frozen language model.
• extend_instance() is an implementation of a mixin in Python.
• obj is the object whose functionality we want to extend.
• mixin is the class whose functionality we want to add to obj.
• Using type(), a new class is dynamically created; its 2nd argument lists the classes to inherit from (the base class and the mixin). The 3rd argument holds any additional attributes for the new class.
• Here, extend_instance() is used to add gated cross-attention functionality between vision and language inputs to the language model.
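A minimal sketch of this mixin technique is shown below. `TinyLM` and `VisionMixin` are hypothetical stand-ins for the language model and the Flamingo mixin, and listing the mixin first in the bases tuple (so that its methods take priority) is an assumption of this sketch.

```python
def extend_instance(obj, mixin):
    """Give obj the methods of mixin by swapping in a dynamically built class."""
    base_cls = obj.__class__
    # type(name, bases, attributes): the mixin comes first so its methods win
    obj.__class__ = type(base_cls.__name__, (mixin, base_cls), {})


class TinyLM:  # hypothetical stand-in for a frozen language model
    def forward(self, text):
        return f"lm({text})"


class VisionMixin:  # hypothetical stand-in for the Flamingo mixin
    def forward(self, text):
        # Run the added behaviour, then defer to the original model's forward().
        return "gated_xattn+" + super().forward(text)


lm = TinyLM()
extend_instance(lm, VisionMixin)
result = lm.forward("hello")
print(result)                  # gated_xattn+lm(hello)
print(isinstance(lm, TinyLM))  # True
```

The existing object is modified in place, so its weights and attributes are kept while its behaviour gains the mixin's methods.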
Let’s understand FlamingoLMMixin now!
Open FLAMINGO implementation
Code: mlfoundations/open_flamingo: An open-source framework for training large multimodal models. (github.com)
FLAMINGO ARCHITECTURE: GATED X-ATTENTION (2/7)
Open FLAMINGO implementation
Building blocks:
• MaskedCrossAttention - implements cross-attention between text and image with masking
• GatedCrossAttentionBlock - adds tanh gating over the masked cross attention
• FlamingoLayer - stores both the GatedCrossAttentionBlock and the decoderBlock for a layer
FlamingoLMMixin (nn.Module) pseudo code
init_flamingo_layers()
    # causal LLM decoder layers
    decoderBlocks = List[CausalLlmDecoderBlock]
    gatedXAttentionBlocks = List[GatedCrossAttentionBlock]
    for layerIndex in range(len(decoderBlocks)):
        # a gated cross-attention block goes in front of every n-th decoder layer
        if (layerIndex + 1) % cross_attn_every_n_layers == 0:
            gatedXAttentionBlocks.append(GatedCrossAttentionBlock())
        else:
            gatedXAttentionBlocks.append(None)
    flamingoLayers = List[FlamingoLayer]
    for d, x in zip(decoderBlocks, gatedXAttentionBlocks):
        flamingoLayers.append(FlamingoLayer(d, x))
    CausalLlm.decoderLayers = flamingoLayers
forward()
    Calls super().forward(), which in this case is the CausalLlm (decoder) model's forward().
    As the CausalLlm (decoder) model's layers have been replaced with FlamingoLayer, FlamingoLayer.forward() is called for each layer.
FLAMINGO ARCHITECTURE: GATED X-ATTENTION (3/7)
Open FLAMINGO implementation
FlamingoLayer stores both the GatedCrossAttentionBlock and the decoderBlock for a layer. The LLM decoder layers are replaced with FlamingoLayer by the FlamingoLMMixin mixin.
FlamingoLayer (nn.Module) pseudo code
self.gated_x_attention: GatedCrossAttentionBlock (or None)
self.decoder_layer: original LLM decoder layer
condition_vis_x(vision_features)
    # vision features from the vision encoder (after the Perceiver Resampler)
    self.vis_x = vision_features
condition_media_locations(media_locations)
    # boolean mask over the input ids: True where input id == id of the <image> token
    self.media_locations = media_locations
forward(lang_x: [B, L, txt_dim], attention_mask: [B, L]):
    # Cross-attention blocks are added every n layers based on config. If this layer has a
    # gated cross-attention block, call it first and then the LLM decoder layer; otherwise
    # call the LLM decoder layer directly.
    if self.gated_x_attention is not None:
        lang_x = self.gated_x_attention(lang_x, self.vis_x, self.media_locations)
        # lang_x.shape = [B, L, txt_dim]
    lang_x = self.decoder_layer(lang_x, attention_mask)
    return lang_x  # lang_x.shape = [B, L, txt_dim]
FLAMINGO ARCHITECTURE: GATED X-ATTENTION (4/7)
Open FLAMINGO implementation
MaskedCrossAttention implements cross-attention between text and image with masking.
MaskedCrossAttention (nn.Module) pseudo code
init()
    num_heads  # number of attention heads
    head_dim   # dimensionality of each attention head
    vis_dim    # visual features dimensionality
    dim        # input and output dimension of this module (txt_dim)
forward(text_features: [B, L, txt_dim], visual_features: [B, T, G, vis_dim], visual_locations_in_text: bool[B, L])
    text_features = layerNorm(text_features)  # [B, L, txt_dim]
    hidden_dim = num_heads * head_dim
    q = LinearProjection(text_features)  # dim → hidden_dim; [B, L, hidden_dim]
    visual_features = rearrange(visual_features, "B T G vis_dim -> B (T G) vis_dim")  # [B, T*G, vis_dim]
    k = LinearProjection(visual_features)  # vis_dim → hidden_dim; [B, T*G, hidden_dim]
    v = LinearProjection(visual_features)  # vis_dim → hidden_dim; [B, T*G, hidden_dim]
    q, k, v = rearrange each from [B, n, hidden_dim] to [B, num_heads, n, head_dim]
    # n = L for text (q) and n = T*G for images (k, v)
    score = q @ k^T  # standard attention also scales q by 1/sqrt(head_dim)
    # [B, num_heads, L, head_dim] @ [B, num_heads, head_dim, T*G] → [B, num_heads, L, T*G]
    # Now we want to nullify some scores based on the image index (1..T): a text token may
    # attend either to all past images up to its position (<=) or only to the image
    # immediately preceding it (==).
    media_time = torch.arange(T) + 1  # e.g. T = 2 gives media_time = [1, 2]
    # let's say L = 5, B = 2
    # visual_locations_in_text = tensor([[True, True, True, False, True], [False, False, True, False, True]])
    text_time = visual_locations_in_text.cumsum(dim=-1)  # cumulative sum along L
    # text_time = tensor([[1, 2, 3, 3, 4], [0, 0, 1, 1, 2]])
FLAMINGO ARCHITECTURE: GATED X-ATTENTION (5/7)
Open FLAMINGO implementation
MaskedCrossAttention (nn.Module) pseudo code
…
forward(text_features: [B, L, txt_dim], visual_features: [B, T, G, vis_dim], visual_locations_in_text: bool[B, L])
    …
    # prepare for comparison of text tokens and image tokens along the time dimension
    re_text_time = einops.rearrange(text_time, "b i -> b 1 i 1")
    # re_text_time = tensor([[[[1],[2],[3],[3],[4]]], [[[0],[0],[1],[1],[2]]]])
    # re_text_time.shape = [2, 1, 5, 1]
    # say, G = visual_features.shape[2] = 8
    re_media_time = einops.repeat(media_time, "j -> 1 1 1 (j n)", n=G)  # media_time = [1, 2]
    # re_media_time = tensor([[[[1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]]]])
    # re_media_time.shape = [1, 1, 1, 16]
    if only the immediately preceding image should be visible:
        text_to_media_mask = torch.eq(re_text_time, re_media_time)
    else:  # all past images, i.e. every image whose index is <= the text token's index
        text_to_media_mask = torch.ge(re_text_time, re_media_time)
    # text_to_media_mask.shape = [2, 1, 5, 16]; generically, [B, 1, L, T*G]
    # With torch.eq (showing 1 entry for brevity): each entry in re_text_time (here, [1]) is
    # compared for equality with all 16 entries in re_media_time using broadcasting;
    # so, for the 1st entry: 1 eq 1 = True
    # text_to_media_mask = tensor([[[[True, True, True, True, True, True, True, True,
    #     False, False, False, False, False, False, False, False], … ]]])
    # With torch.ge (showing 1 entry for brevity): each entry in re_text_time (here, [2]) is
    # compared (ge) with all 16 entries in re_media_time; so, for the 1st entry: 2 ge 1 = True
    # text_to_media_mask = tensor([[[[True, True, True, True, True, True, True, True,
    #     True, True, True, True, True, True, True, True], … ]]])
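The broadcasting comparison can be reproduced without tensors for a single sequence. This is a plain-Python sketch: the function and parameter names are assumptions, G is reduced to 2 visual tokens per image for brevity, and `text_time` is the second row of the cumulative-sum example ([0, 0, 1, 1, 2]).

```python
def text_to_media_mask(text_time, num_media, tokens_per_media, only_immediate):
    """Build an [L, T*G] boolean mask for one sequence."""
    # media_time repeats each image index G times: [1]*G + [2]*G + ...
    media_time = [t for t in range(1, num_media + 1) for _ in range(tokens_per_media)]
    if only_immediate:
        return [[tt == mt for mt in media_time] for tt in text_time]
    return [[tt >= mt for mt in media_time] for tt in text_time]


text_time = [0, 0, 1, 1, 2]
eq_mask = text_to_media_mask(text_time, num_media=2, tokens_per_media=2, only_immediate=True)
ge_mask = text_to_media_mask(text_time, num_media=2, tokens_per_media=2, only_immediate=False)

print(eq_mask[4])  # [False, False, True, True]    -> only the 2nd image
print(ge_mask[4])  # [True, True, True, True]      -> both past images
print(eq_mask[0])  # [False, False, False, False]  -> no preceding image
```

The last row shows why a follow-up step is needed: a token with no preceding image ends up with a fully masked row.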
FLAMINGO ARCHITECTURE: GATED X-ATTENTION (6/7)
Open FLAMINGO implementation
MaskedCrossAttention (nn.Module) pseudo code
…
forward(text_features: [B, L, txt_dim], visual_features: [B, T, G, vis_dim], visual_locations_in_text: bool[B, L])
    …
    # score.shape = [B, num_heads, L, T*G]
    # text_to_media_mask.shape = [B, 1, L, T*G]
    # Fill the score with a large negative number wherever the mask is False.
    # masked_fill() fills wherever its mask is True; hence, we invert the mask with ~
    score = score.masked_fill(~text_to_media_mask, -torch.finfo(score.dtype).max)
    score = score - score.amax(dim=-1, keepdim=True)  # subtract the row max for numerical stability
    attn = score.softmax(dim=-1)
    if exists(visual_locations_in_text) and only the immediately preceding image is attended:
        # if a text token does not have any preceding image, its attention weights should be zeroed
        text_without_media_mask = text_time == 0
        text_without_media_mask = rearrange(text_without_media_mask, "b i -> b 1 i 1")
        attn = attn.masked_fill(text_without_media_mask, 0.0)
    # attn.shape = [B, num_heads, L, T*G]
    # v.shape = [B, num_heads, T*G, head_dim]
    out = attn @ v
    # out.shape = [B, num_heads, L, head_dim]
    reshape out → [B, L, (num_heads * head_dim)]
    out = LinearProjection(out)  # hidden_dim → dim; out.shape = [B, L, dim]
    return out
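The reason for the extra zeroing step is easiest to see on a single row of scores. In this plain-Python sketch the constant `NEG` stands in for the large negative fill value, and the function name is an assumption.

```python
import math

NEG = -1e30  # finite stand-in for "minus infinity"


def masked_attention_weights(scores, mask):
    """Softmax over one row of scores; an all-masked row is returned as zeros."""
    filled = [s if keep else NEG for s, keep in zip(scores, mask)]
    top = max(filled)
    exps = [math.exp(s - top) for s in filled]
    total = sum(exps)
    weights = [e / total for e in exps]
    if not any(mask):
        # Without this step the softmax of an all-masked row would be uniform.
        weights = [0.0] * len(weights)
    return weights


partly_masked = masked_attention_weights([2.0, 1.0, 0.5, 3.0], [True, True, False, False])
fully_masked = masked_attention_weights([2.0, 1.0, 0.5, 3.0], [False, False, False, False])
print(partly_masked)  # weight only on the first two positions
print(fully_masked)   # [0.0, 0.0, 0.0, 0.0]
```

After the fill and the max subtraction, a fully masked row is all zeros going into the softmax, which would spread the attention evenly over every image; zeroing the row prevents text with no preceding image from attending to anything.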
FLAMINGO ARCHITECTURE: GATED X-ATTENTION (7/7)
Open FLAMINGO implementation
GatedCrossAttentionBlock (nn.Module) pseudo code
init()
    self.crossAttention = MaskedCrossAttention()
    self.attn_gate = LearnableParameter(initial value = 0.0)
    self.feedforward = Sequential(LayerNorm(dim), Linear(dim, 4*dim), GeLU(), Linear(4*dim, dim))
    self.feedforward_gate = LearnableParameter(initial value = 0.0)
forward(text_features: [B, L, txt_dim], visual_features: [B, T, G, vis_dim], visual_locations_in_text: bool[B, L])
    crossAttentionOut = self.crossAttention(text_features, visual_features, visual_locations_in_text)
    # crossAttentionOut.shape = [B, L, txt_dim]
    # tanh() controls the magnitude of the cross-attention output. Training begins with no
    # cross-attention contribution, as tanh(0.0) = 0.0, so the text features from the frozen
    # LLM remain as is. As training progresses, the attn_gate parameter adapts to control
    # the magnitude.
    crossAttentionOut = crossAttentionOut * tanh(self.attn_gate)  # [B, L, txt_dim]
    x = text_features + crossAttentionOut  # [B, L, txt_dim]
    # Same gating as attn_gate, applied to the feedforward MLP with its own residual connection.
    x = self.feedforward(x) * tanh(self.feedforward_gate) + x  # [B, L, txt_dim]
    return x  # [B, L, txt_dim]
[Figure: plot of tanh]
FLAMINGO ARCHITECTURE: FLAMINGO
Open FLAMINGO implementation
Flamingo (nn.Module) pseudo code
init()
    self.visionEncoder = OpenCLIP.VisionTower
    self.perceiver = PerceiverResampler  # see Perceiver Resampler slide
    self.llm = CausalLlm  # any decoder-style causal LLM
    extend_instance(self.llm, FlamingoLMMixin)  # see FlamingoLMMixin slide
    self.llm.init_flamingo_layers(…)  # see FlamingoLMMixin slide
    self.media_token_id = token id of <image>
forward(text_inputIds: [B, L], image_inputs: [B, T_img, F=1, C, H, W]):
    # Images in the same chunk are collated along T_img (the # of images in a text chunk, delimited by <|endofchunk|>).
    # Frames are collated along F; currently only F = 1 is supported.
    image_inputs = rearrange(image_inputs, "B T_img F C H W -> (B T_img F) C H W")
    vision_features = self.visionEncoder(image_inputs)  # tokens output - see CLIP vision tower slide
    # vision_features.shape = [(B * T_img * F), (PxP)+1, D]
    vision_features = rearrange(vision_features, "(B T_img F) v D -> B T_img F v D")  # v = (PxP)+1
    perceiver_features = self.perceiver(vision_features)  # [B, T_img, #latents, D]
    visual_locations_in_text = text_inputIds == self.media_token_id  # bool [B, L]
    for layer in self.llm.get_decoder_layers():
        # FlamingoLMMixin has replaced each decoder layer with a FlamingoLayer (see FlamingoLMMixin slide),
        # so each layer is an instance of FlamingoLayer
        layer.condition_vis_x(perceiver_features)  # see FlamingoLayer slide
        layer.condition_media_locations(visual_locations_in_text)
    # now call CausalLlm.forward(), which in turn calls FlamingoLayer.forward() for each LLM decoder layer
    output = self.llm(text_inputIds, …)
    return output  # next-token logits, [B, L, vocab size]
FLAMINGO TRAINING INFO
• Dataset: LAION-2B, MMC4, synthetic: ChatGPT
• Objective function:

$$\sum_{m=1}^{M} \lambda_m \, \mathbb{E}_{(x, y) \sim D_m} \left[ -\sum_{l=1}^{L} \log p(y_l \mid y_{<l}, x_{<l}) \right]$$

• Tuning the per-dataset weights $\lambda_m$ is key to performance.
• All models trained using 64 GPUs distributed across 8 nodes using either Distributed Data Parallel (DDP) or Fully Sharded Data Parallel (FSDP) techniques.
$D_m$ = m-th dataset
$\lambda_m$ = dataset weightage
$L$ = sequence length
$y$ = text token
$x$ = image token
Open FLAMINGO implementation
Paper - OpenFlamingo: An Open-Source Framework for Training Large Autoregressive Vision-Language Models
Negative log-likelihood loss: minimize the weighted sum, over the datasets (𝐷𝑚), of the expected negative log-likelihood of each text token (𝑦𝑙) conditioned on the text (𝑦<𝑙) and visual inputs (𝑥<𝑙) preceding the current text token.
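As a worked example of this objective, the sketch below computes the weighted sum for two toy datasets. The per-token probabilities, the dataset weights and the function names are all made up for illustration; they are not values from the OpenFlamingo paper.

```python
import math


def sequence_nll(token_probs):
    """Negative log-likelihood of one sequence, given p(y_l | y_<l, x_<l) for each token."""
    return -sum(math.log(p) for p in token_probs)


def mixture_loss(batches, weights):
    """Weighted sum over datasets of the mean per-sequence NLL."""
    total = 0.0
    for name, sequences in batches.items():
        expected_nll = sum(sequence_nll(seq) for seq in sequences) / len(sequences)
        total += weights[name] * expected_nll
    return total


# Made-up token probabilities and per-dataset weights, for illustration only.
batches = {
    "laion": [[0.5, 0.5], [0.25, 1.0]],
    "mmc4": [[0.5, 0.25, 0.5]],
}
weights = {"laion": 1.0, "mmc4": 2.0}
loss = mixture_loss(batches, weights)
print(loss)  # 10 * ln(2), roughly 6.93
```

Both "laion" sequences have an NLL of 2 ln 2, the "mmc4" sequence has 4 ln 2, and with weights 1 and 2 the total is 2 ln 2 + 8 ln 2 = 10 ln 2. Changing a weight changes how much each dataset contributes to the gradient.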
All OpenFlamingo models are trained with CLIP ViT-L/14 as the vision encoder.
FLAMINGO ARCHITECTURE - PUTTING IT TOGETHER
Open FLAMINGO implementation
[Figure: end-to-end diagram. Labels: ViT and Causal LLM with a contrastive learning objective; ViT, Perceiver Resampler, GatedCrossAttentionBlock (MaskedCrossAttention) and Decoder Layers, with the causal LLM's decoder layers processed to FlamingoLayer; Tokenizer; input webpage; next-token prediction objective (CrossEntropyLoss)]
THANK YOU
Source: Announcing OpenFlamingo: An open-source framework for training vision-language models with in-context learning | LAION


Understanding Flamingo - DeepMind's VLM Architecture

  • 1. Understanding FLAMINGO: A Visual Language Model for Few-Shot Learning. A DeepMind paper. Paper: Flamingo: a Visual Language Model for Few-Shot Learning
  • 2. OUTLINE Visual Language Models (VLMs) capable of adapting to novel tasks with few-shot learning (simply by prompting the model with task-specific examples). Key architectural innovations: • Bridge powerful pretrained language-only and vision-only models → Gated X-Attention • Handle sequences of arbitrarily interleaved visual and textual data • Seamlessly ingest images or videos as inputs. Source (paper): Flamingo: a Visual Language Model for Few-Shot Learning
  • 3. FLAMINGO ARCHITECTURE Diagram labels: Responsible for providing a fixed set of visual features for an image or video; Pre-trained NFNet (Normalizer-Free ResNet), F6 version; Pretrained following CLIP with NFNet and Chinchilla; Pre-trained DeepMind Chinchilla LM layers; Input format; Responsible for learning cross-attention scores among visual and textual features. Source (paper): Flamingo: a Visual Language Model for Few-Shot Learning
  • 4. OPENFLAMINGO The Flamingo code and model are not currently open source. However, thanks to the amazing open-source AI research community, there is an open-source implementation: OpenFlamingo. We will explore OpenFlamingo to understand the Visual Language Model world! Paper: OpenFlamingo: An Open-Source Framework for Training Large Autoregressive Vision-Language Models. Code: mlfoundations/open_flamingo: An open-source framework for training large multimodal models. (github.com) Dataset: • LAION-2B: an open-source web-scraped dataset consisting of 2B image-text pairs • MMC4: Multimodal C4, an open-source dataset of 101M interleaved samples of images and sentences in documents, soft-aligned using CLIP. • Synthetic (ChatGPT): 417K image-text sequences generated by prompting ChatGPT to generate a sequence of interleaved text and image alt-texts (in place of images).
  • 5. ARCHITECTURE: VISION ENCODER Datasets used for vision encoder pre-training: • ALIGN dataset: large-scale dataset with ~1.8 billion image-text pairs obtained from image alt-text pairs, leveraging the inherently noisy nature of alt-text data. • LTIP dataset: Long Text & Image Pairs dataset; 312 million image-text pairs with longer descriptions than the other datasets used for training. Based on contrastive pre-training from scratch following the CLIP paper (CLIP: Contrastive Language-Image Pre-training), with Chinchilla (Flamingo) in place of GPT-2 (in CLIP) and NFNet F6 (Flamingo) in place of ResNet (in CLIP). The core idea of the pre-training (CLIP) is: 1. given a batch of image-text pairs, 2. compare all text and image embedding combinations using cosine similarity, 3. train to maximize the score between matching text-image pairs and minimize the score between incorrect ones (also called contrastive learning). CLIP: Learning Transferable Visual Models From Natural Language Supervision (arxiv.org) NFNet: High-Performance Large-Scale Image Recognition Without Normalization (arxiv.org)
  • 6. ARCHITECTURE: VISION ENCODER OpenFlamingo implementation. OpenFlamingo is based on the OpenCLIP paper, code repository and models. OpenCLIP: Reproducible scaling laws for contrastive language-image learning. Like CLIP, it is based on a twin-tower model: • Text tower: responsible for encoding the text input. Based on Transformer-based encoder-decoder or decoder-style language models. Applies a pooling layer to get the final features for comparison. • Several language models have been trained; examples are RoBERTa, XLM-RoBERTa, mT5, etc. • Vision tower: responsible for encoding the image input. • Typically 2 vision model architectures: • Modified ResNet • ViT • Global average pooling is replaced by Transformer-style multi-head QKV attention where the query is conditioned on the global average-pooled representation of the image. The architecture together with the objective function tries to increase the scores on the diagonal and reduce the remaining scores. Let's explore the OpenCLIP blocks in a little more detail!
  • 7. ARCHITECTURE - VISION ENCODER: OpenCLIP vision tower, Vision Transformer (ViT) (1/2). ViT paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. OpenFlamingo implementation.
      Input: image (H, W, C=3), patch size (ph, pw).
      Grid height GH = H / ph; grid width GW = W / pw; number of patches = GH x GW.
      Typically H = W and ph = pw, so GH = GW = P.
      Convolution (kernel = patch size, stride = patch size, out channels = D): X = [B, D, P, P]
      Tensor transforms: X = [B, D, P, P] → [B, D, P x P] → [B, P x P, D]
      Concatenate the class embedding CLS = [D] (learnable parameter): X = CONCAT(X, CLS) = [B, (P x P)+1, D]
      Add the positional embedding pos = [(P x P)+1, D] (2D learnable or 2D sinusoidal): X = X + pos = [B, (P x P)+1, D]
      Example: H = 224, W = 224, C = 3 (e.g., RGB), ph = pw = 14, so P = 224 / 14 = 16. For a 224x224 image with 3 channels and a patch size of 14 x 14 (each patch is 14 x 14 x 3 pixels), the number of patches is 16 x 16.
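The patch arithmetic on this slide can be checked with a few lines of plain Python. This is only a sketch of the shape bookkeeping; the helper name `vit_token_count` is made up for illustration and is not part of OpenCLIP or OpenFlamingo.

```python
def vit_token_count(height, width, patch_h, patch_w, add_cls=True):
    """Return (grid_h, grid_w, sequence_length) for a ViT-style patch embedding.

    Hypothetical helper: it mirrors the slide's arithmetic, where a convolution
    with kernel = stride = patch size yields one token per patch and a single
    class embedding is concatenated to the patch tokens.
    """
    if height % patch_h or width % patch_w:
        raise ValueError("image size must be divisible by the patch size")
    grid_h = height // patch_h
    grid_w = width // patch_w
    return grid_h, grid_w, grid_h * grid_w + (1 if add_cls else 0)


# The slide's example: a 224x224 image with 14x14 patches.
print(vit_token_count(224, 224, 14, 14))  # (16, 16, 257)
```

With the class embedding included, the sequence length is (P x P)+1 = 257, which is the middle dimension of X in the shapes above.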
  • 8. ARCHITECTURE - VISION ENCODER: OpenCLIP vision tower, Vision Transformer (ViT) (2/2). ViT paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. OpenFlamingo implementation.
      X = [B, (PxP)+1, D]
      Patch dropout (prob, exclude class embedding), https://arxiv.org/abs/2212.00794: randomly drops patch embeddings according to the given probability, with the option of excluding the class embedding from the drop.
      Layer norm: X = [B, (PxP)+1, D]
      Transformer (# layers, # heads, GeLU, ...) with Q = K = V = X: tokens = [B, (PxP)+1, D]
      Attention pooling (multi-head attention) with Q = learnable parameter = [# queries, out dim], K = X, V = X: pool = [B, # queries, out dim]
      Linear output projection (output dim): pool = [B, # queries, output dim]
      return (pool, tokens)
      Attention pooling is analogous to max pooling / average pooling in convolutional models. The number of queries determines the type of pooling. If # queries < (PxP)+1, we are doing a form of aggregate pooling using the attention mechanism. If # queries = 1, we are doing a form of global pooling into a single vector; this setup is typically used in contrastive learning tasks. Notice that Q has out dim features whereas K and V have D features. CLIP needs a global representation of the image; hence the pooling layer is added.
  • 9. ARCHITECTURE - VISION ENCODER: OpenCLIP text tower, Transformer-based language models, e.g., RoBERTa or T5. OpenCLIP implementation.
      Text → Input ids = Tokenizer(Text) = [B, N]; [Input ids] = [B, N, E]
      Hugging Face Transformers: transformer = AutoModel.from_pretrained(); out = transformer.forward(input ids); out.last_hidden_state = [B, N, hidden dim]
      Pooling (based on configuration, taking the attention mask into account): [B, N, hidden dim] → [B, hidden dim]
      Mean pooling: returns the mean values along the sequence length dimension (N).
      Max pooling: returns the max values along the sequence length dimension (N).
      CLS pooling: returns the [CLS] token representation from the hidden state, which is typically at index 0 along the sequence length dimension (N).
      Linear projection / MLP: [B, hidden dim] → [B, output dim]
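The three pooling modes can be illustrated on a single sequence with plain Python lists. This is a minimal sketch, not the OpenCLIP code: the function name `pool_hidden_states` and the list-based layout are assumptions made for readability, and the real implementation operates on batched tensors.

```python
def pool_hidden_states(hidden, mask, mode="mean"):
    """Pool one sequence of hidden states, shape [N][hidden_dim], into [hidden_dim].

    hidden: list of N token vectors; mask: list of N values (1 = real token,
    0 = padding). Hypothetical helper for illustration only.
    """
    if mode == "cls":
        # CLS pooling: take the token at index 0 of the sequence dimension.
        return list(hidden[0])
    kept = [row for row, m in zip(hidden, mask) if m]
    if not kept:
        raise ValueError("mask removes every token")
    dim = len(hidden[0])
    if mode == "mean":
        return [sum(row[d] for row in kept) / len(kept) for d in range(dim)]
    if mode == "max":
        return [max(row[d] for row in kept) for d in range(dim)]
    raise ValueError(f"unknown pooling mode: {mode}")


hidden = [[1.0, 4.0], [3.0, 2.0], [100.0, 100.0]]  # last token is padding
mask = [1, 1, 0]
print(pool_hidden_states(hidden, mask, "mean"))  # [2.0, 3.0]
print(pool_hidden_states(hidden, mask, "max"))   # [3.0, 4.0]
print(pool_hidden_states(hidden, mask, "cls"))   # [1.0, 4.0]
```

The padded token never influences the mean or max, which is the point of considering the mask during pooling.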
  • 10. ARCHITECTURE - VISION ENCODER: OpenCLIP contrastive learning. OpenCLIP implementation. The idea of the CLIP paper is: 1. given a batch of image-text pairs, 2. compare all text and image embedding combinations using cosine similarity, 3. train to maximize the score between matching text-image pairs and minimize the score between incorrect ones (the core idea of contrastive learning). The objective function has 2 contrastive losses, one from text to image and the other from image to text:
      $$L_{\text{contrastive:txt2img}} = -\frac{1}{N} \sum_{i}^{N} \log \frac{\exp(L_i^{\top} V_i \beta)}{\sum_{j}^{N} \exp(L_i^{\top} V_j \beta)}$$
      $$L_{\text{contrastive:img2txt}} = -\frac{1}{N} \sum_{i}^{N} \log \frac{\exp(V_i^{\top} L_i \beta)}{\sum_{j}^{N} \exp(V_i^{\top} L_j \beta)}$$
      where $L_i$ = normalized language embedding of the i-th element in the batch, $V_i$ = normalized vision embedding of the i-th element in the batch, $\beta$ = trainable inverse temperature parameter, and $N$ = number of elements in the batch.
      [Figure: pseudo code for the core of an implementation of CLIP.] Source: CLIP: Learning Transferable Visual Models From Natural Language Supervision (arxiv.org)
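The two contrastive losses on this slide can be sketched in dependency-free Python. This is an illustrative toy, not the OpenCLIP implementation: the function names are invented here, embeddings are plain lists, and the two losses are returned separately rather than combined, since the slide does not state how they are weighted.

```python
import math


def normalize(vec):
    """Scale a vector to unit length so dot products are cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def directional_loss(queries, keys, beta):
    """-1/N * sum_i log( exp(q_i.k_i * beta) / sum_j exp(q_i.k_j * beta) )."""
    n = len(queries)
    total = 0.0
    for i in range(n):
        logits = [beta * dot(queries[i], keys[j]) for j in range(n)]
        top = max(logits)  # subtract the max for numerical stability
        log_denominator = top + math.log(sum(math.exp(z - top) for z in logits))
        total += logits[i] - log_denominator
    return -total / n


def clip_losses(text_emb, image_emb, beta):
    """Return (txt2img, img2txt) losses for a batch of paired embeddings."""
    lang = [normalize(v) for v in text_emb]
    vis = [normalize(v) for v in image_emb]
    return directional_loss(lang, vis, beta), directional_loss(vis, lang, beta)


text = [[1.0, 0.0], [0.0, 1.0]]
images = [[2.0, 0.0], [0.0, 3.0]]          # aligned with the matching text
print(clip_losses(text, images, beta=10.0))        # both losses close to 0
print(clip_losses(text, images[::-1], beta=10.0))  # mismatched pairs: large losses
```

When every text embedding points in the same direction as its own image, the diagonal logits dominate and both losses approach zero; swapping the images makes the off-diagonal logits dominate and the losses grow.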
  • 11. FLAMINGO ARCHITECTURE: PERCEIVER RESAMPLER OpenFlamingo implementation. • Provides a fixed number of visual tokens for cross-attention from the varying-size feature maps of the vision encoder • Input: a variable number of image or video features from the vision encoder [image index, features, feature dimensions]; tokens = [B, (PxP)+1, D] • Output: a fixed number of visual tokens based on the # of latent queries; x.shape = [B, T, #latents, D] • Reduces the computational complexity of the vision-language cross-attention. Currently, OpenFlamingo implements only single-frame videos in the Perceiver Resampler. Source (paper): Flamingo: a Visual Language Model for Few-Shot Learning
  • 12. FLAMINGO ARCHITECTURE: GATED X-ATTENTION OpenFlamingo implementation. • Interleaving of • pre-trained and frozen text-only transformer-based language model decoder blocks with • blocks trained from scratch that cross-attend to the visual output from the Perceiver Resampler. • To ensure that at initialization the conditioned model yields the same results as the original language model, a tanh-gating mechanism is used. • For the tanh(𝛼) gating, 𝛼 is a layer-specific learnable scalar parameter initialized to 0. [Figure: plot of tanh.] Source (paper): Flamingo: a Visual Language Model for Few-Shot Learning
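Why the tanh gate preserves the frozen language model at initialization can be seen in a tiny numeric sketch. The helper name `gated_residual` and the list-based vectors are assumptions for illustration; the real blocks operate on tensors.

```python
import math


def gated_residual(lang_features, branch_output, alpha):
    """Add a new branch's output to the language features through a tanh gate.

    Illustrative helper: `branch_output` stands in for the output of a freshly
    initialized cross-attention (or feed-forward) branch, and `alpha` is the
    layer-specific learnable scalar.
    """
    gate = math.tanh(alpha)
    return [y + gate * b for y, b in zip(lang_features, branch_output)]


y = [0.5, -1.0, 2.0]          # activations from the frozen language model
branch = [10.0, -3.0, 7.0]    # arbitrary output of the untrained branch

print(gated_residual(y, branch, alpha=0.0))  # [0.5, -1.0, 2.0]: unchanged at init
print(gated_residual(y, branch, alpha=1.0))  # branch now contributes tanh(1) ~ 0.76 of its output
```

Because tanh(0) = 0, the untrained branch contributes nothing at the start of training, and its influence grows smoothly (bounded to (-1, 1)) as 𝛼 is learned.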
  • 13. FLAMINGO ARCHITECTURE: GATED X-ATTENTION - MASKING OpenFlamingo implementation. • Text is processed by replacing vision data (e.g., images) with an <image> tag, and the text is divided into chunks using the <EOC> (end of chunk) tag. Each chunk contains at most one image, which is always at the start of the chunk, signifying that the subsequent text is assumed to relate only to that image. The beginning of a sequence is also marked with the <BOS> (beginning of sequence) tag. • The sequence is then tokenized; each token is given an index (𝜙) corresponding to the index of the preceding image. 0 means no preceding image. • Using 𝜙, a mask is created that ensures the text tokens (𝑌) can only cross-attend to the image tokens (𝑋) that correspond to their chunk. In the figure, the mask is illustrated by dark blue entries (unmasked/visible) and light blue entries (masked). • For cross-attention, Query = 𝑌, Key = Value = [𝑋]. Source (paper): Flamingo: a Visual Language Model for Few-Shot Learning
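The index 𝜙 described above can be computed with a running count of <image> tags. The sketch below uses whole words as "tokens" and the hypothetical helper names `preceding_image_index` and `can_attend`; it follows the cumulative-sum convention, in which the <image> token itself already carries the index of its own image.

```python
def preceding_image_index(tokens, image_token="<image>"):
    """Return phi: for each token, the 1-based index of the last image seen (0 = none)."""
    phi = []
    images_seen = 0
    for token in tokens:
        if token == image_token:
            images_seen += 1
        phi.append(images_seen)
    return phi


def can_attend(phi_value, image_index):
    """A text token may cross-attend only to the image of its own chunk."""
    return phi_value == image_index


tokens = ["<BOS>", "<image>", "A", "cat", "<EOC>", "<image>", "A", "dog", "<EOC>"]
phi = preceding_image_index(tokens)
print(phi)                     # [0, 1, 1, 1, 1, 2, 2, 2, 2]
print(can_attend(phi[3], 1))   # True: "cat" sees image 1
print(can_attend(phi[7], 1))   # False: "dog" does not see image 1
```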
  • 14. FLAMINGO ARCHITECTURE: GATED X-ATTENTION (1/7) OpenFlamingo implementation. • The concept of a mixin is used to add gated cross-attention functionality to a frozen language model. • extend_instance() is an implementation of a mixin in Python. • obj is the object whose functionality we want to extend. • mixin is the class whose functionality we want to add to obj. • Using type(), a new class is dynamically created that inherits from both the base class and the mixin (the 2nd argument). The 3rd argument holds any additional attributes for the new class. • Here, extend_instance() is used to add gated cross-attention functionality between vision and language inputs. Let's understand FlamingoLMMixin now! Code: mlfoundations/open_flamingo: An open-source framework for training large multimodal models. (github.com)
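The mixin mechanism described on this slide can be reproduced in a few lines of standard Python. This is a sketch written from the slide's description rather than a copy of the repository code; `TinyLM` and `CountingMixin` are toy classes invented for the example.

```python
def extend_instance(obj, mixin):
    """Give an existing object the behaviour of `mixin` without re-creating it.

    type(name, bases, attrs) builds a new class at runtime. Listing the mixin
    first puts it ahead of the original class in the method resolution order,
    so its methods run first and can delegate with super().
    """
    base_cls = obj.__class__
    obj.__class__ = type(base_cls.__name__, (mixin, base_cls), {})


class TinyLM:
    """Stand-in for a frozen language model (toy class)."""

    def forward(self, x):
        return x + 1


class CountingMixin:
    """Stand-in for FlamingoLMMixin: wraps forward() and then defers to the base."""

    def forward(self, x):
        self.calls = getattr(self, "calls", 0) + 1
        return super().forward(x)


lm = TinyLM()
extend_instance(lm, CountingMixin)
print(lm.forward(41), lm.calls)                              # 42 1
print(isinstance(lm, TinyLM), isinstance(lm, CountingMixin))  # True True
```

The object keeps its existing state (for a real model, its weights) because only its class is swapped, which is why this pattern suits adding layers to an already-loaded language model.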
  • 15. FLAMINGO ARCHITECTURE: GATED X-ATTENTION (2/7) OpenFlamingo implementation.
      MaskedCrossAttention: implements cross-attention between text and image with masking. GatedCrossAttentionBlock: adds tanh gating over the masked cross-attention. FlamingoLayer: stores both the GatedCrossAttentionBlock and the decoder block for a layer.
      FlamingoLMMixin (nn.Module) pseudo code
      init_flamingo_layers()
          # causal LLM decoder layers
          decoderBlocks = List[CausalLlmDecoderBlock]
          gatedXAttentionBlocks = List[GatedCrossAttentionBlock]
          for layerIndex, _ in enumerate(decoderBlocks):
              # add a gated cross-attention block every cross_attn_every_n_layers layers
              if (layerIndex + 1) % cross_attn_every_n_layers == 0:
                  gatedXAttentionBlocks.append(GatedCrossAttentionBlock())
              else:
                  gatedXAttentionBlocks.append(None)
          flamingoLayers = List[FlamingoLayer]
          for d, x in zip(decoderBlocks, gatedXAttentionBlocks):
              flamingoLayers.append(FlamingoLayer(d, x))
          CausalLlmDecoderBlock.decoderLayers = flamingoLayers
      forward() calls super().forward(), which in this case is the CausalLlm (decoder) model's forward(). As the CausalLlm (decoder) model's layers have been replaced with FlamingoLayer, its forward() will be called.
  • 16. FLAMINGO ARCHITECTURE: GATED X-ATTENTION (3/7) OpenFlamingo implementation.
      FlamingoLayer stores both the GatedCrossAttentionBlock and the decoder block for a layer. The LLM decoder layers are replaced with FlamingoLayer by the FlamingoLMMixin mixin.
      FlamingoLayer (nn.Module) pseudo code
      self.gated_x_attention: GatedCrossAttentionBlock
      self.decoder_layer: original LLM decoder layer
      condition_vis_x(vision_features)
          # vision features from the vision encoder
          self.vis_x = vision_features
      condition_media_locations(media_locations)
          # boolean mask over input ids: True where input id == id of the <image> token
          self.media_locations = media_locations
      forward(lang_x: [B, L, txt_dim], attention_mask: [B, L])
          # Cross-attention layers are added every n layers based on the config. If the gated
          # cross-attention is not None, call it first and then call the LLM decoder layer;
          # otherwise call the LLM decoder layer directly.
          if self.gated_x_attention is not None:
              lang_x = self.gated_x_attention(lang_x, self.vis_x, self.media_locations)  # [B, L, txt_dim]
          lang_x = self.decoder_layer(lang_x, attention_mask)
          return lang_x  # [B, L, txt_dim]
  • 17. FLAMINGO ARCHITECTURE: GATED X-ATTENTION (4/7) OpenFlamingo implementation.
      MaskedCrossAttention implements cross-attention between text and image with masking.
      MaskedCrossAttention (nn.Module) pseudo code
      init()
          num_heads  # number of attention heads
          head_dim   # dimensionality of each attention head
          vis_dim    # visual features dimensionality
          dim        # input and output dimension of this module
      forward(text_features: [B, L, txt_dim], visual_features: [B, T, G, vis_dim], visual_locations_in_text: bool [B, L])
          text_features = layerNorm(text_features)  # [B, L, txt_dim]
          hidden_dim = num_heads * head_dim
          q = LinearProjection(text_features)  # dim → hidden_dim; [B, L, hidden_dim]
          visual_features = rearrange(visual_features, (B, (T * G), vis_dim))  # [B, T*G, vis_dim]
          k = LinearProjection(visual_features)  # vis_dim → hidden_dim; [B, T*G, hidden_dim]
          v = LinearProjection(visual_features)  # vis_dim → hidden_dim; [B, T*G, hidden_dim]
          rearrange(q, k, v, from: (B, n, hidden_dim), to: (B, num_heads, n, head_dim))
          # now q, k, v have shape [B, num_heads, n, head_dim], where n = L for text and n = T*G for images
          score = q * k^T  # (B, num_heads, L, head_dim) * (B, num_heads, head_dim, T*G) → [B, num_heads, L, T*G]
          # Now nullify some scores depending on which image index (T) a text token may attend to:
          # either all past images up to the current text token (<= T), or only the image
          # immediately preceding the current text token (= T).
          media_time = torch.arange(T) + 1  # say T = 2; media_time = [1, 2]
          # say L = 5, B = 2 (values chosen only to illustrate the cumulative sum)
          # visual_locations_in_text = tensor([[True, True, True, False, True], [False, False, True, False, True]])
          text_time = visual_locations_in_text.cumsum(dim=-1)  # cumulative summation
          # text_time = tensor([[1, 2, 3, 3, 4], [0, 0, 1, 1, 2]])
  • 18. FLAMINGO ARCHITECTURE: GATED X-ATTENTION (5/7) OpenFlamingo implementation.
      MaskedCrossAttention implements cross-attention between text and image with masking. GatedCrossAttentionBlock adds tanh gating over the masked cross-attention.
      MaskedCrossAttention (nn.Module) pseudo code …
      forward(text_features: [B, L, txt_dim], visual_features: [B, T, G, vis_dim], visual_locations_in_text: bool [B, L]) …
          # prepare to compare text tokens and image tokens along the time dimension
          re_text_time = einops.rearrange(text_time, "b i -> b 1 i 1")
          # re_text_time = tensor([[[[1], [2], [3], [3], [4]]], [[[0], [0], [1], [1], [2]]]])
          # re_text_time.shape = [2, 1, 5, 1]
          # say G = visual_features.shape[2] = 8
          re_media_time = einops.repeat(media_time, "j -> 1 1 1 (j n)", n=visual_features.shape[2])
          # media_time = [1, 2]
          # re_media_time = tensor([[[[1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]]]])
          # re_media_time.shape = [1, 1, 1, 16]
          if only the immediately preceding image (index T) should be allowed:
              text_to_media_mask = torch.eq(re_text_time, re_media_time)
          else:  # all past images (index T), i.e. as long as the text token's time is >= the media one
              text_to_media_mask = torch.ge(re_text_time, re_media_time)
          # text_to_media_mask.shape = [2, 1, 5, 16]; generically, [B, 1, L, T*G]
          # With torch.eq (showing 1 entry of 16 values for brevity): each entry in re_text_time (here, [1]) is
          # compared for equality with all 16 entries in re_media_time (here, [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2])
          # using broadcasting; so, for the 1st entry, 1 eq 1 = True.
          # text_to_media_mask = tensor([[[[True, True, True, True, True, True, True, True, False, False, False, False, False, False, False, False], … ]]])
          # With torch.ge (showing 1 entry of 16 values for brevity): each entry in re_text_time (here, [2]) is
          # compared (ge) with all 16 entries in re_media_time (here, [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2]);
          # so, for the 1st entry, 2 ge 1 = True.
          # text_to_media_mask = tensor([[[[True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True], … ]]])
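The mask construction can be traced end to end for one sequence without any tensor library. The sketch below is an assumption-laden simplification: it handles a single batch element, uses nested lists instead of broadcasting, and the function name `build_text_to_media_mask` is invented for the example.

```python
from itertools import accumulate
from operator import eq, ge


def build_text_to_media_mask(media_locations, num_media, tokens_per_media,
                             only_immediate=True):
    """Return an [L][T*G] boolean mask for one sequence.

    media_locations: list of L booleans, True where the token is <image>.
    num_media (T) and tokens_per_media (G) describe the visual features.
    """
    # Cumulative sum: index of the most recent image for each text position.
    text_time = list(accumulate(int(flag) for flag in media_locations))
    # Each image index (1..T) repeated once per visual token of that image.
    media_time = [t + 1 for t in range(num_media) for _ in range(tokens_per_media)]
    compare = eq if only_immediate else ge
    return [[compare(t, m) for m in media_time] for t in text_time]


locations = [False, False, True, False, True]   # second row of the slide's example
for row in build_text_to_media_mask(locations, num_media=2, tokens_per_media=2):
    print(row)
# [False, False, False, False]   no preceding image
# [False, False, False, False]
# [True, True, False, False]     image 1 only
# [True, True, False, False]
# [False, False, True, True]     image 2 only

print(build_text_to_media_mask(locations, 2, 2, only_immediate=False)[-1])
# [True, True, True, True]       with ge, the last token sees both images
```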
  • 19. FLAMINGO ARCHITECTURE: GATED X-ATTENTION (6/7) OpenFlamingo implementation.
      MaskedCrossAttention (nn.Module) pseudo code …
      forward(text_features: [B, L, txt_dim], visual_features: [B, T, G, vis_dim], visual_locations_in_text: bool [B, L]) …
          # score.shape = [B, num_heads, L, T*G]
          # text_to_media_mask.shape = [B, 1, L, T*G]
          # Now update the score with a large negative number wherever the mask is False.
          # masked_fill() fills wherever its mask argument is True; hence, we invert the mask with ~
          score = score.masked_fill(~text_to_media_mask, -torch.finfo(score.dtype).max)
          score = score - score.amax(dim=-1, keepdim=True)
          attn = score.softmax(dim=-1)
          if exists(visual_locations_in_text) and only the immediately preceding image is attended to:
              # if a text token does not have any preceding image, its attention weights should be zeroed
              text_without_media_mask = text_time == 0
              text_without_media_mask = rearrange(text_without_media_mask, "b i -> b 1 i 1")
              attn = attn.masked_fill(text_without_media_mask, 0.0)
          # attn.shape = [B, num_heads, L, T*G]
          # v.shape = [B, num_heads, T*G, head_dim]
          out = attn * v  # out.shape = [B, num_heads, L, head_dim]
          reshape out → [B, L, (num_heads * head_dim)]
          out = LinearProjection(out)  # to dim; out.shape = [B, L, dim]
          return out
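The effect of the masked softmax, including the zeroing of rows for text tokens with no preceding image, can be shown on plain lists. This is a sketch with a swapped-in shortcut: instead of filling masked scores with a large negative number and zeroing rows afterwards, the hypothetical helper `masked_attention_weights` simply skips masked entries, which gives the same weights for this purpose.

```python
import math


def masked_attention_weights(scores, mask):
    """Softmax each row of `scores` over the positions where `mask` is True.

    scores, mask: [L][T*G] nested lists for a single head. Rows whose mask is
    all False (text with no preceding image) get all-zero weights.
    """
    weights = []
    for score_row, mask_row in zip(scores, mask):
        if not any(mask_row):
            weights.append([0.0] * len(score_row))
            continue
        top = max(s for s, keep in zip(score_row, mask_row) if keep)
        exps = [math.exp(s - top) if keep else 0.0
                for s, keep in zip(score_row, mask_row)]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


scores = [[1.0, 1.0, 5.0],
          [2.0, 3.0, 4.0]]
mask = [[True, True, False],      # may only see the first two visual tokens
        [False, False, False]]    # no preceding image at all
print(masked_attention_weights(scores, mask))
# [[0.5, 0.5, 0.0], [0.0, 0.0, 0.0]]
```

Without the explicit zeroing, a fully masked row would softmax to a uniform distribution over all visual tokens, so the text token would attend to images it should not see.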
  • 20. FLAMINGO ARCHITECTURE: GATED X-ATTENTION (7/7) OpenFlamingo implementation.
      GatedCrossAttentionBlock (nn.Module) pseudo code
      init()
          self.crossAttention = MaskedCrossAttention()
          self.attn_gate = LearnableParameter(initial value = 0.0)
          self.feedforward = LayerNorm(dim) → Linear(dim, 4*dim) → GeLU → Linear(4*dim, dim)
          self.feedforward_gate = LearnableParameter(initial value = 0.0)
      forward(text_features: [B, L, txt_dim], visual_features: [B, T, G, vis_dim], visual_locations_in_text: bool [B, L])
          crossAttentionScores = self.crossAttention(text_features, visual_features, visual_locations_in_text)
          # crossAttentionScores.shape = [B, L, txt_dim]
          # tanh() controls the magnitude of the cross-attention output. Training begins with no
          # cross-attention contribution, as tanh(0.0) = 0.0, so the text features from the frozen LLM
          # remain as is. As training progresses, the attn_gate parameter adapts to control the magnitude.
          crossAttentionScores = crossAttentionScores * torch.tanh(self.attn_gate)  # [B, L, txt_dim]
          x = text_features + crossAttentionScores  # [B, L, txt_dim]
          # Like attn_gate, but for the feed-forward MLP; the gated MLP output is added to x as a residual.
          feedforwardOut = self.feedforward(x)  # [B, L, txt_dim]
          x = (feedforwardOut * torch.tanh(self.feedforward_gate)) + x  # [B, L, txt_dim]
          return x  # [B, L, txt_dim]
      [Figure: plot of tanh.]
  • 21. FLAMINGO ARCHITECTURE: FLAMINGO OpenFlamingo implementation.
      Flamingo (nn.Module) pseudo code
      init()
          self.visionEncoder = OpenCLIP.VisionTower
          self.perceiver = PerceiverResampler  # see Perceiver Resampler slide
          self.llm = CausalLlm  # any decoder-style causal LLM
          self.llm = extend_instance(self.llm, FlamingoLMMixin)  # see FlamingoLMMixin slide
          self.llm.init_flamingo_layers(…)  # see FlamingoLMMixin slide
          self.media_token_id = token id of <image>
      forward(text_inputIds: [B, L], image_inputs: [B, T_img, F=1, C, H, W])
          # Images are collated along T_img, i.e., the # of images; each text chunk ends with <|endofchunk|>.
          # Frames are collated along F; currently only F=1 is supported.
          image_inputs = rearrange(image_inputs, "B T_img F C H W -> (B T_img F) C H W")
          vision_features = self.visionEncoder(image_inputs)  # tokens output; see CLIP vision tower slide
          # vision_features.shape = [(B * T_img * F), (PxP)+1, D]
          vision_features = rearrange(vision_features, "(B T_img F) (PxP)+1 D -> B T_img F (PxP)+1 D")
          perceiver_features = self.perceiver(vision_features)  # [B, T_img, #latents, D]
          visual_locations_in_text = text_inputIds == self.media_token_id
          for layer in self.llm.get_decoder_layers():
              # FlamingoLMMixin replaces each decoder layer with a FlamingoLayer (see FlamingoLMMixin slide),
              # so each layer is an instance of FlamingoLayer
              layer.condition_vis_x(perceiver_features)  # see FlamingoLayer slide
              layer.condition_media_locations(visual_locations_in_text)
          # now call CausalLlm.forward(), which in turn calls FlamingoLayer.forward() for each LLM decoder layer
          output = self.llm(text_inputIds, …)
          return output  # output.shape = [B, L, txt_dim]
  • 22. FLAMINGO TRAINING INFO OpenFlamingo implementation. • Dataset: LAION-2B, MMC4, synthetic (ChatGPT) • Objective function (to be minimized):
      $$\sum_{m=1}^{M} \lambda_m \, \mathbb{E}_{(x,y) \sim D_m} \left[ -\sum_{l=1}^{L} \log p(y_l \mid y_{<l}, x_{<l}) \right]$$
      where $D_m$ = m-th dataset, $\lambda_m$ = dataset weight, $L$ = sequence length, $y$ = text token, $x$ = image token.
      This is a negative log-likelihood loss, i.e., minimize the weighted sum of the per-dataset ($D_m$) expectations of the negative log-likelihood of each text token ($y_l$) conditioned on the text ($y_{<l}$) and visual inputs ($x_{<l}$) preceding the current text token. • Tuning the per-dataset weights $\lambda_m$ is key to performance. • All models are trained using 64 GPUs distributed across 8 nodes using either Distributed Data Parallel (DDP) or Fully Sharded Data Parallel (FSDP) techniques. All OpenFlamingo models are trained with CLIP ViT-L/14 as the vision encoder. Paper: OpenFlamingo: An Open-Source Framework for Training Large Autoregressive Vision-Language Models
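The weighted objective can be made concrete with a toy calculation. In this sketch the model is replaced by hand-written probabilities p(y_l | y_<l, x_<l) for each text token, the expectation is approximated by a batch mean, and the names `sequence_nll` and `weighted_objective` as well as the example weights are assumptions chosen for illustration.

```python
import math


def sequence_nll(token_probs):
    """Negative log-likelihood of one sequence: -sum_l log p(y_l | y_<l, x_<l)."""
    return -sum(math.log(p) for p in token_probs)


def weighted_objective(datasets):
    """Sum over datasets of lambda_m times the mean sequence NLL of that dataset.

    datasets: list of (lambda_m, batch) pairs, where batch is a list of
    sequences and each sequence is a list of per-token probabilities.
    """
    total = 0.0
    for weight, batch in datasets:
        mean_nll = sum(sequence_nll(seq) for seq in batch) / len(batch)
        total += weight * mean_nll
    return total


# Two toy "datasets": one with uncertain predictions, one predicted perfectly.
image_text_pairs = [[0.5, 0.5]]      # NLL = 2 * ln 2
interleaved_docs = [[1.0, 1.0, 1.0]]  # NLL = 0
loss = weighted_objective([(1.0, image_text_pairs), (0.2, interleaved_docs)])
print(round(loss, 4))  # 1.3863
```

Changing a weight rescales how strongly that dataset's errors drive the gradient, which is why tuning the per-dataset weights matters.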
  • 23. FLAMINGO ARCHITECTURE - PUTTING IT TOGETHER OpenFlamingo implementation. [Figure: overview diagram. OpenCLIP: ViT trained with a contrastive learning objective. Flamingo: input webpage → tokenizer; ViT → Perceiver Resampler; causal LLM whose decoder layers are processed to FlamingoLayer (GatedCrossAttentionBlock with MaskedCrossAttention, followed by the decoder layer); next-token prediction objective [CrossEntropyLoss].]
  • 24. THANK YOU Source: Announcing OpenFlamingo: An open-source framework for training vision-language models with in-context learning | LAION