Transformers for Image Recognition: ViT Architecture Explained
1. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
2022/05/02, Changjin Lee
2. Introduction
● Transformers have become the de-facto standard for NLP tasks but had not yet matured for computer vision tasks (at the time ViT was published)
● Many approaches had been proposed for integrating transformers into computer vision tasks
○ Partially replacing CNN layers with transformer blocks
○ Using transformers in conjunction with CNNs
● ViT proposes a fully transformer-based architecture for image classification
14. ViT Architecture Overview
Input image: (B, C, H, W)
[1] Split into patches & flatten
[2] Linear Projection
[3] Class Tokens
[4] Position Embedding
[5] Encoder block 1 (MHA)
[6] Encoder block 2 (MLP)
(encoder blocks [5]-[6] repeated x L)
[7] Classification Head
Three stages:
1. Patch Embedding
2. Transformer Encoder
3. Classification Head
15. Flatten & Linear Projection
1. Split the image into patches and flatten each patch
2. Apply a linear projection to get the embedded vectors
a. The linear projection is trainable
[Figure: an image with C channels is split into P x P patches, which are flattened and linearly projected]
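A minimal PyTorch sketch of this step (not the presenter's code): the class name PatchEmbedding and the ViT-Base-style defaults (16x16 patches, 768-d embeddings) are assumptions for illustration. For a 224x224 input, this gives (224/16)^2 = 196 patch tokens.

import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    # Split an image into non-overlapping P x P patches, flatten each patch,
    # and apply a trainable linear projection to get one embedding per patch.
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2      # e.g. (224 / 16)^2 = 196 patches
        self.proj = nn.Linear(in_chans * patch_size * patch_size, embed_dim)

    def forward(self, x):                                     # x: (B, C, H, W)
        B, C, H, W = x.shape
        P = self.patch_size
        x = x.unfold(2, P, P).unfold(3, P, P)                 # (B, C, H/P, W/P, P, P)
        x = x.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P)   # (B, N, C*P*P) flattened patches
        return self.proj(x)                                   # (B, N, embed_dim)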
16. Class Token
● Prepend a learnable embedding, the class token, to the sequence of embedded patches
● The class token at the output of the encoder serves as the image representation
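A minimal sketch of prepending the class token; the 768-d embedding size and the helper name prepend_cls_token are my own choices:

import torch
import torch.nn as nn

embed_dim = 768                                              # assumed embedding size
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))       # learnable class-token embedding

def prepend_cls_token(patch_tokens):                         # patch_tokens: (B, N, D)
    B = patch_tokens.shape[0]
    cls = cls_token.expand(B, -1, -1)                        # same learnable vector for every image
    return torch.cat([cls, patch_tokens], dim=1)             # (B, N+1, D); the class token sits at index 0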
17. Position Embedding
● A plain transformer does not retain any information about the relative ordering of the patches
● ViT adds a learnable 1D position embedding to every token
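A minimal sketch of the learnable 1D position embedding, assuming 196 patches and a 768-d embedding (illustrative sizes, not from the slides):

import torch
import torch.nn as nn

embed_dim, num_patches = 768, 196                            # assumed sizes (224x224 image, 16x16 patches)
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))   # one learnable vector per token, including the class token

def add_position_embedding(tokens):                          # tokens: (B, N+1, D)
    return tokens + pos_embed                                 # broadcast over the batch dimension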
19. Encoder Block 1: Multi-head Attention
● Vectorized Implementation
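A minimal sketch of vectorized multi-head self-attention: Q, K and V for all heads come from single linear layers and attention is computed with batched matrix multiplications. The class name and the ViT-Base-style sizes (768-d, 12 heads) are assumptions:

import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    # Vectorized multi-head self-attention: all heads are computed in parallel.
    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)                   # one projection produces Q, K, V for all heads
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                                    # x: (B, N, D)
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)                 # each: (B, heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.scale        # (B, heads, N, N) attention scores
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)    # concatenate the heads back together
        return self.proj(out)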
20. Encoder Block 2: MLP
● GELU non-linearity
● Dropout
● Expansion: the hidden layer is wider than the embedding dimension
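A minimal sketch of the MLP block with GELU, dropout, and an expanded hidden layer; the class name and the expansion factor of 4 are assumed defaults:

import torch.nn as nn

class MLPBlock(nn.Module):
    # Two-layer MLP with GELU and dropout; the hidden layer expands the embedding dimension.
    def __init__(self, dim=768, expansion=4, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * expansion),                 # expansion
            nn.GELU(),                                       # GELU non-linearity
            nn.Dropout(dropout),                             # dropout
            nn.Linear(dim * expansion, dim),                 # project back to the embedding dimension
            nn.Dropout(dropout),
        )

    def forward(self, x):                                    # x: (B, N, D)
        return self.net(x)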
21. Classification Head
● A classification head is attached to the class-token output for the final prediction
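A minimal sketch of the classification head, assuming the common choice of a LayerNorm followed by a single linear layer applied to the class token; the names and the 1000-class output are illustrative:

import torch.nn as nn

embed_dim, num_classes = 768, 1000                           # assumed sizes
norm = nn.LayerNorm(embed_dim)
head = nn.Linear(embed_dim, num_classes)                     # a single linear layer as the classification head

def classify(encoder_output):                                # encoder_output: (B, N+1, D)
    cls = norm(encoder_output)[:, 0]                         # image representation = class token at index 0
    return head(cls)                                         # (B, num_classes) logits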
22. ViT: Putting it all together
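A minimal end-to-end sketch combining the pieces above (it reuses the PatchEmbedding, MultiHeadSelfAttention and MLPBlock sketches from the earlier slides); the pre-norm residual structure follows the ViT paper, while the depth, width, and head counts are ViT-Base-style assumptions:

import torch
import torch.nn as nn

class ViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 dim=768, depth=12, num_heads=12, num_classes=1000):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, dim)                 # [1]-[2]
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))                                  # [3]
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, dim))   # [4]
        self.blocks = nn.ModuleList([
            nn.ModuleDict({
                "norm1": nn.LayerNorm(dim), "attn": MultiHeadSelfAttention(dim, num_heads),    # [5]
                "norm2": nn.LayerNorm(dim), "mlp": MLPBlock(dim),                              # [6]
            })
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)                                                # [7]

    def forward(self, x):                                    # x: (B, C, H, W)
        x = self.patch_embed(x)                              # (B, N, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)      # (B, 1, D)
        x = torch.cat([cls, x], dim=1) + self.pos_embed      # (B, N+1, D)
        for blk in self.blocks:
            x = x + blk["attn"](blk["norm1"](x))             # pre-norm residual attention
            x = x + blk["mlp"](blk["norm2"](x))              # pre-norm residual MLP
        return self.head(self.norm(x)[:, 0])                 # logits from the class token: (B, num_classes)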
23. Transformers (ViT) need a lot of data!
● On smaller datasets, ViT performs worse than ResNet-based models!
[Plots: ViT vs. ResNet performance as the pre-training dataset grows larger]
But... WHY?
24. Inductive Bias
● Inductive bias is any assumption we make about the unseen data
● Example: house price prediction
○ Features: house size, # floors, # bedrooms
○ Model 1: a plain MLP with billions of parameters (no assumptions)
■ Needs TONS of data to figure out the underlying relationship from scratch
○ Model 2: linear regression
■ We “assume” that the features are linearly related to the house price
■ If our assumption is correct -> more data-efficient!
● Relational inductive bias
○ Assumptions about the relationships between entities in the network (e.g., entities = pixels)
25. CNN - Relational Inductive Bias
● Locality
○ Kernels capture local relationships between the entities inside the kernel window
● 2D Neighborhood Structure
● Translation Equivariance: when the input is shifted, the output shifts in the same way (see the sketch below)
● Translation Invariance: when the input is shifted, the output does not change
● These assumptions are a good fit for image-related tasks!
[Figure: translation equivariance]
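A small toy check (my own example, not from the slides) illustrating translation equivariance of a convolution; circular padding is used so the equality holds exactly even at the image borders:

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, padding_mode='circular', bias=False)
x = torch.randn(1, 1, 8, 8)
shifted_input = torch.roll(x, shifts=1, dims=-1)             # shift the input one pixel to the right
out_of_shifted = conv(shifted_input)                         # convolve the shifted input
shifted_output = torch.roll(conv(x), shifts=1, dims=-1)      # shift the output of the original input
print(torch.allclose(out_of_shifted, shifted_output, atol=1e-6))   # True: the shift commutes with the conv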
26. Transformer - Inductive Bias
● Transformers have only a weak image-specific inductive bias
● In ViT, only the MLP layers are local and translation-equivariant
● Self-attention is global!
● The 2D neighborhood structure is used only when
○ The image is cut into patches
○ Position embeddings are adjusted (at fine-tuning time, for images of different resolution)
● This weak inductive bias means the transformer has to learn the 2D positions of the patches and all spatial relations between the patches from scratch, which requires an extensive dataset
● With small-to-medium datasets, ViT performs worse than CNNs, but with large datasets, ViT outperforms CNNs