Transformer解説
~Chat-GPTの源流~
1
Chat-GPTを理解したい
Chat-GPTすごい
APIが公開され、活用アプリ&怪しい
記事が爆増
CSer, ISerとして、
根底から理解しよう
2
あくまで私は計算機屋さん
細かい理論についてはわからん
ü大規模言語モデルのお気持ちには触れつつ、
あくまで、その計算フローに焦点を当てる
ü本資料では学習方法は深く取り扱わない
3
⽬次
• ざっくり深層学習
• 自然言語処理モデル
• RNN
• Transformer
4
深層学習が何をやってるか
深層学習
…複雑な関数を、単純な線形変換()を大量に重ねて近似すること
5
Rabbit
Kawauso
Cat
function(DNN)
単純な線形変換()
𝑦 = 𝐴𝑐𝑡𝑖𝑣𝑎𝑡𝑒(𝑥𝑊 + 𝑏)
(以降こいつをLinearって呼びます)
Øこいつの重ね方に工夫が生まれる
Ø𝑊, 𝑏の値をよしなに調整するのが学習
6
𝑥
𝑊
𝑏
Activate
Func
etc.
𝑦
非線形性を生み、
表現力が向上
*
⾃然⾔語処理
人間の言語の解釈を要するタスクを機械に解かせたい
Ø曖昧性の高さから、深層学習によるアプローチが主流
• 文章要約
• Q & A
• 翻訳
7
Chat-GPTへの道のり
BERT
8
GPT GPT-2 GPT-3 Chat-GPT
Transformer
全てはTransformerから始まった
Øまずはコイツから始めましょう!
ざっくりTransformer
“Attention Is All You Need”
(Ashish Vaswani @Google Brain et al. )
機械翻訳用の自然言語モデル
従来のRNNベースの手法から大幅に性能改善
Ø自然言語処理のbreak throughを作った革命的なモデル
9
Decoder
Encoder
翻訳の主流︓Encoder-Decoderモデル
10
DNN
単語ベクトル群
I am a man .
文の意味っぽいベクトル
DNN
私 は 人 だ 。
文の意味っぽいベクトル
単語ベクトル群
機械翻訳の祖︓RNN
Recurrent Neural Network
… 入力長分、共通の線形変換()を繰り返し適用するモデル
可変長の入力に対応可能、系列データ全般に強い
11
Linear Linear Linear Linear Linear
x0 x1 x2 x3 x4
y0 y1 y2 y3 y4
s0 s1 s2 s3 s4
Encoder
RNNで機械翻訳 - Encoder部分
12
I have a pen .
Linear Linear Linear Linear Linear
x0 x1 x2 x3 x4
s0 s1 s2 s3
文
全
体
の
意
味
Word Embedding(実はこいつもDNN)
Decoder
RNNで機械翻訳 - Decoder部分
13
私は ペンを 持って いる 。
Linear Linear Linear Linear Linear
y0 y1 y2 y3 y4
t0 t1 t2 t3
文
全
体
の
意
味
Word Embedding(実はこいつもDNN)
Decoder
Encoder
RNNの問題点
• 計算フローのクリティカルパスが文の長さに比例
ØGPU等の並列計算で高速化できない
14
14
I have a pen .
Linear Linear Linear Linear Linear
x0 x1 x2 x3 x4
s0 s1 s2 s3 文
全
体
の
意
味
Word Embedding(実はこいつもDNN)
私は ペンを 持って いる 。
Linear Linear Linear Linear Linear
y0 y1 y2 y3 y4
t0 t1 t2 t3
文
全
体
の
意
味
Word Embedding(実はこいつもDNN)
Transformer
並列性の高い計算フローを持つ
Encoder-Decoder型DNN
主要なパーツ
• Positional Encoding
• Feed-Forward Network
• Layer Normalization
• Multi-Head Attention
15
Masked Multi-Head
Attention
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
Linear
softmax
Output Embedding
+
+
+
〜
Input Embedding
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
+
+
+
〜
Outputs
Inputs
Output Probabilities
+
Layer Norm
N x
x N
まずは超ざっくり⾒る
1. 入力文をEncode
2. 出力済の文と1の結果から、
次単語の確率分布を生成
3. ビームサーチで次単語確定、
出力済の文に追加
4. 2に戻る
16
Masked Multi-Head
Attention
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
+
+
+
〜
Output Probabilities
+
Layer Norm
Encoder
Decoder
I am a boy .
⼊⼒⽂の意味
私は
男の子 男性 女の子 犬
80% 10% 6% 4%
Transformer
主要なパーツ
• Positional Encoding
• Feed-Forward Network
• Layer Normalization
• Multi-Head Attention
17
Masked Multi-Head
Attention
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
Linear
softmax
Output Embedding
+
+
+
〜
Input Embedding
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
+
+
+
〜
Outputs
Inputs
Output Probabilities
+
Layer Norm
N x
x N
Positional Encoding
文の意味解釈で、各単語の位置情報は重要
LLinear層は単語の順序を考慮しない
Ø入力時点で、単語自体に位置情報を明示的に埋め込む必要性
18
I am a boy .
Word Embedding
単
語
ベ
ク
ト
ル
pos
i
𝑃𝐸 𝑝𝑜𝑠, 2𝑖 = sin
!"#
$%%%%
!"
#
𝑃𝐸 𝑝𝑜𝑠, 2𝑖 + 1 = cos(
!"#
$%%%%
!"
#
)
𝑑
Transformer
主要なパーツ
• Positional Encoding
• Feed-Forward Network
• Layer Normalization
• Multi-Head Attention
19
Masked Multi-Head
Attention
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
Linear
softmax
Output Embedding
+
+
+
〜
Input Embedding
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
+
+
+
〜
Outputs
Inputs
Output Probabilities
+
Layer Norm
N x
x N
Feed-Forward Network
𝑧 = 𝑅𝑒𝐿𝑈 𝑥𝑊! + 𝑏!
𝑦 = 𝑧𝑊" + 𝑏"
ØLinear x2。それだけ
20
𝑥 𝑊! 𝑧
𝑧
𝑊"
𝑦
※bは省略
Transformer
主要なパーツ
• Positional Encoding
• Feed-Forward Network
• Layer Normalization
• Multi-Head Attention
21
Masked Multi-Head
Attention
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
Linear
softmax
Output Embedding
+
+
+
〜
Input Embedding
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
+
+
+
〜
Outputs
Inputs
Output Probabilities
+
Layer Norm
N x
x N
Layer Normalization
𝐿𝑁 𝑥# =
𝑥# − 𝜇
𝜎
𝛾 + 𝛽
𝜇, 𝜎: 𝑥#の平均, 標準偏差
𝛾, 𝛽: パラメタ(スカラ値)
ただの正規化もどき
Ø学習の高速化や過学習の抑制に寄与
行単位で適用
22
1
…
3
…
2
…
6
…
I am a boy
Layers…
…
…
…
…
-0.9 0 -0.51.4
LN
LN
Transformer
主要なパーツ
• Positional Encoding
• Feed-Forward Network
• Layer Normalization
• Multi-Head Attention
23
Masked Multi-Head
Attention
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
Linear
softmax
Output Embedding
+
+
+
〜
Input Embedding
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
+
+
+
〜
Outputs
Inputs
Output Probabilities
+
Layer Norm
N x
x N
Positional
Encoding
Positional
Encoding
Multi-Head Attention
𝑴𝒖𝒍𝒕𝒊𝑯𝒆𝒂𝒅(𝑄, 𝐾, 𝑉) = 𝒄𝒐𝒏𝒄𝒂𝒕 ℎ𝑒𝑎𝑑# 𝑊$
ℎ𝑒𝑎𝑑# = 𝑺𝑫𝑷𝑨𝒕𝒕𝒆𝒏𝒕𝒊𝒐𝒏 𝑄𝑊
#
%
, 𝐾𝑊#
&
, 𝑉𝑊#
'
𝑺𝑫𝑷𝑨𝒕𝒕𝒆𝒏𝒕𝒊𝒐𝒏 𝑄′, 𝐾′, 𝑉′ = 𝒔𝒐𝒇𝒕𝒎𝒂𝒙
𝑄′𝐾′(
𝑑
𝑉′
お気持ち
• 𝑉には、整理されていない有益情報がたくさん
• 𝐾は𝑉に紐づく情報がたくさん
• 𝑄に近い情報がKにあれば、対応する有益情報を𝑉から抽出 24
Scaled Dot Product Attention
𝑆𝐷𝑃𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 𝑄, 𝐾, 𝑉 = 𝑠𝑜𝑓𝑡𝑚𝑎𝑥
𝑄𝐾(
𝑑
𝑉
お気持ち
• 𝑘#(𝑘𝑒𝑦), 𝑣#(𝑣𝑎𝑙𝑢𝑒)という対を為すベクトルが沢山
• 各入力ベクトル𝑞)と似ているkeyを集める
• keyに対応するvalueたちを混ぜて出力
25
𝑞!
𝑞"
𝑞#
𝑞$
𝑘!
𝑘"
𝑘#
𝑣!
𝑣"
𝑣#
Scaled Dot Product Attention①
𝑆𝐷𝑃𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 𝑄, 𝐾, 𝑉 = 𝑠𝑜𝑓𝑡𝑚𝑎𝑥
𝑄𝐾(
𝑑
𝑉
𝑄𝐾(
の各要素は𝑞#と𝑘)の内積
Ø𝑞#, 𝑘)の向きが近いほど値が大きいため、類似度の指標に
(内積はベクトル長に比例してしまうため、 𝑑で割る)
26
𝑞!
𝑞"
𝑞#
𝑞$
𝑘! 𝑘" 𝑘# 𝑘$ 𝑘%
*
𝑑
𝑞! ∗ 𝑘!
𝑞" ∗ 𝑘!
𝑞# ∗ 𝑘!
𝑞! ∗ 𝑘"
𝑞" ∗ 𝑘"
𝑞# ∗ 𝑘"
𝑞! ∗ 𝑘#
𝑞" ∗ 𝑘#
𝑞# ∗ 𝑘#
𝑞$ ∗ 𝑘!𝑞$ ∗ 𝑘"𝑞$ ∗ 𝑘#
𝑞!と各keyとの
類似度ベクトル
Scaled Dot Product Attention②
𝑆𝐷𝑃𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 𝑄, 𝐾, 𝑉 = 𝑠𝑜𝑓𝑡𝑚𝑎𝑥
𝑄𝐾(
𝑑
𝑉
• 𝒔𝒐𝒇𝒕𝒎𝒂𝒙 𝒙 = [
𝒆𝒙𝟏
∑ 𝒆𝒙𝒊
,
𝒆𝒙𝟐
∑ 𝒆𝒙𝒊
,
𝒆𝒙𝟑
∑ 𝒆𝒙𝒊
, … ]
Øベクトルを少し過激に確率分布に変換する関数
ex.) 𝑠𝑜𝑓𝑡𝑚𝑎𝑥([2,3,5]) = [0.4, 0.11, 0.85]
27
𝑞! ∗ 𝑘!
𝑞" ∗ 𝑘!
𝑞# ∗ 𝑘!
𝑞! ∗ 𝑘"
𝑞" ∗ 𝑘"
𝑞# ∗ 𝑘"
𝑞! ∗ 𝑘#
𝑞" ∗ 𝑘#
𝑞# ∗ 𝑘#
𝑞$ ∗ 𝑘!𝑞$ ∗ 𝑘"𝑞$ ∗ 𝑘#
softmax
softmax
softmax
softmax
𝑞!~𝑘!
𝑞"~𝑘!
𝑞#~𝑘!
𝑞!~𝑘"
𝑞"~𝑘"
𝑞#~𝑘"
𝑞!~𝑘#
𝑞"~𝑘#
𝑞#~𝑘#
𝑞$~𝑘! 𝑞$~𝑘" 𝑞$~𝑘#
𝑞!と各keyとの
類似性の確率分布
Scaled Dot Product Attention③
𝑆𝐷𝑃𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 𝑄, 𝐾, 𝑉 = 𝑠𝑜𝑓𝑡𝑚𝑎𝑥
𝑄𝐾(
𝑑
𝑉
前stepで求めた確率分布を重みと捉え、valuesを加重平均
28
𝑞!~𝑘!
𝑞"~𝑘!
𝑞#~𝑘!
𝑞!~𝑘"
𝑞"~𝑘"
𝑞#~𝑘"
𝑞!~𝑘#
𝑞"~𝑘#
𝑞#~𝑘#
𝑞$~𝑘! 𝑞$~𝑘" 𝑞$~𝑘#
𝑣!
𝑣"
𝑣#
*
[0.4, 0.11, 0.85]
% 𝑞!~𝑘% ∗ 𝑣%
% 𝑞"~𝑘% ∗ 𝑣%
% 𝑞#~𝑘% ∗ 𝑣%
% 𝑞$~𝑘% ∗ 𝑣%
𝑀𝑢𝑙𝑡𝑖𝐻𝑒𝑎𝑑 𝑄, 𝐾, 𝑉 = 𝑐𝑜𝑛𝑐𝑎𝑡 ℎ𝑒𝑎𝑑# 𝑊$
ℎ𝑒𝑎𝑑# = 𝑆𝐷𝑃𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 𝑄𝑊
#
%
, 𝐾𝑊#
&
, 𝑉𝑊#
'
Ø𝑄, 𝐾, 𝑉に様々な変換を加え、組合わせに多様性を持たせている
Multi-Head Attention
29
𝑊!
"
𝑊!
#
𝑄
𝐾
𝑉
𝑊$
"
𝑊$
#
*
*
*
𝑊%
"
𝑊%
#
𝑊
!
&
𝑊
$
&
𝑊
%
&
𝐾!
'
𝑉!
'
𝐾$
'
𝐾%
'
𝑉$
'
𝑉%
'
𝑄!
'
𝑄$
'
𝑄%
'
SDP
Attention
ℎ𝑒𝑎𝑑
!
ℎ𝑒𝑎𝑑
$
ℎ𝑒𝑎𝑑
%
ℎ𝑒𝑎𝑑
!
ℎ𝑒𝑎𝑑
$
ℎ𝑒𝑎𝑑
%
*
𝑊&
Multi-Head Attentionの使われ⽅①
Q,K,Vが全て同じ入力(文)
Ø入力(文)を様々な角度で切り出した物同士
を見比べ、注目すべき箇所を決めて出力
30
Masked Multi-Head
Attention
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
Linear
softmax
Output Embedding
+
+
+
〜
Input Embedding
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
+
+
+
〜
Outputs
Inputs
Output Probabilities
+
Layer Norm
N x
x N
I could not put the violin in the bag because it was too big.
𝑊%
"
𝑊%
#
𝑊
%
&
it was too big
SDP Attention
not because violin bag
put in the bag the violin
the violin was too big
Q K V
勝手なイメージ
Multi-Head Attentionの使われ⽅②
𝑄: 加工済み出力文
𝐾, 𝑉: encoderの出力
Ø出力文から、入力文のどの意味がまだ不足
しているか等を判断している?
31
Masked Multi-Head
Attention
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
Linear
softmax
Output Embedding
+
+
+
〜
Input Embedding
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
+
+
+
〜
Outputs
Inputs
Output Probabilities
+
Layer Norm
N x
x N
Masked Multi-Head Attention
主に学習のための機構
学習時は入力文と出力文の模範解答を流す
次単語予測の正解がわからないように、
出力文を一部maskするだけ
32
Masked Multi-Head
Attention
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
Linear
softmax
Output Embedding
+
+
+
〜
Input Embedding
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
+
+
+
〜
Outputs
Inputs
Output Probabilities
+
Layer Norm
N x
x N
パーツ理解完了
最後に流れを再確認して
締めましょう
33
Transformer
34
Masked Multi-Head
Attention
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
Linear
softmax
Output Embedding
+
+
+
〜
Input Embedding
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
+
+
+
〜
Outputs
Inputs
Output Probabilities
+
Layer Norm
N x
x N
Positional
Encoding
Positional
Encoding
次回予告
transformerは本来翻訳家
だが、意味解釈能力が超凄い
これ、何にでも応用できる?
ØGPTs, BERT
35
Masked Multi-Head
Attention
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
Linear
softmax
Output Embedding
+
+
+
〜
Input Embedding
Multi-Head
Attention
Layer Norm
Feed Forward
Layer Norm
+
+
+
〜
Outputs
Inputs
Output Probabilities
+
Layer Norm
N x
x N

transformer解説~Chat-GPTの源流~