実装レベルで学ぶVQVAE
ぱん@かーねる
⽬標: 実装レベルでVQ-VAEを理解する
n 第⼀著者: Aaron van den Oord
n 同著者が書いた関連論⽂
n Neural Discrete Representation Learning (NIPS 2017)
n Generating Diverse High-Fidelity Images with VQ-VAE-2 (NIPS 2019)
n 概要:
n VAEのフレームワークで離散的な潜在変数の学習を可能にし,posterior collapseの問題を解決する
ことで,⾼いクオリティの画像,ビデオ,⾳声のサンプリングを可能にした
VQVAE2による256x256のサンプル
提案⼿法: VQVAEの学習⽅法
n 1: 例えば32x32x3の画像をCNNでエンコードして,8x8xDのfeature mapを出⼒する
n 2: feature mapのそれぞれの1x1xDのベクトルに最も距離が近いものを,予め⽤意したK個の
D次元の埋め込みベクトルに置き換える
n 3: 置き換えた8x8xDのベクトルをデコードして元の画像を復元できるように学習する
提案⼿法: VQVAEの学習⽅法
n 1: 例えば32x32x3の画像をCNNでエンコードして,8x8xDのfeature mapを出⼒する
n 2: feature mapのそれぞれの1x1xDのベクトルに最も距離が近いものを,予め⽤意したK個の
D次元の埋め込みベクトルに置き換える
n 3: 置き換えた8x8xDのベクトルをデコードして元の画像を復元できるように学習する
提案⼿法: VQVAEの学習⽅法
n 1: 例えば32x32x3の画像をCNNでエンコードして,8x8xDのfeature mapを出⼒する
n 2: feature mapのそれぞれの1x1xDのベクトルに最も距離が近いものを,予め⽤意したK個の
D次元の埋め込みベクトルに置き換える
n 3: 置き換えた8x8xDのベクトルをデコードして元の画像を復元できるように学習する
提案⼿法: VQVAEの学習⽅法
n 学習するもの: エンコーダ・デコーダのパラメータ, KクラスxD次元の埋め込みベクトル
n sgはstop gradientで勾配を計算しないの意味
n 再構成の際,埋め込みベクトルに流れた勾配をそのままエンコーダに渡して学習させる
再構成誤差 埋め込みベクトルを
エンコーダベクトル
に近づける
エンコーダベクトルに
埋め込みベクトルを
近づける
提案⼿法: VQVAEのサンプリング⽅法
n 潜在変数の形はベクトルではなく,例えば32x32x3の画像の場合,8x8xDのテンソルになる
n この潜在変数の規則性のようなものを学習するために,PixelCNNを新たに学習させる
n PixelCNNで潜在変数をサンプリングして,これをデコードすることで画像のサンプリング
を⾏う
学習 サンプリング
実験: VQVAEのサンプリング
n 左図: ⼀様分布の潜在変数の事前分布からサンプリングした画像
n 右図: PixelCNNの事前分布からサンプリングした画像
Uniform prior PixelCNN prior
そこそこ
きれい
実験: VQVAEの再構成
n 保持している情報が多い(特に位置)ので,再構成は⾮常に綺麗
考察: なぜ再構成誤差だけで学習可能なのか
n 通常のVAEと同様に,ELBOの最⼤化を考える
n ここで,VQVAEの事後分布は,決定的なone-hotの分布
n また,p(z)が⼀様な事前分布と仮定しているので,KL項は⼀定になり,
学習時にKL項を無視してもよいことになる
実装: ⼤きな流れ
n 表記
n B: バッチサイズ
n C: チャンネル数
n H: ⾼さ
n W: 幅
n K: 埋め込みベクトルの数
n D: 埋め込みベクトルの次元
n 今回はpixelCNNを使ったサンプリング
を除く実装までの解説を⾏う
実装参考
https://nbviewer.jupyter.org/github/zalandoresea
rch/pytorch-vq-vae/blob/master/vq-vae.ipynb
実装: エンコーダ部分(⼀例)
n 2回のconvで画像の⼤き
さを1/4にする
n 次のconvとresblockは⼤
きさとチャネルの数を変
えない
実装: デコーダ部分(⼀例)
n 最後のconv_transは2回
繰り返して画像サイズを4
倍にする
n 最初のconvとresblockは
チャネル数と⼤きさを変
えない
実装: Residual部分
n Resblockのconvは⼊⼒次
元と出⼒次元を同じにす
る
n そうしないともとの⼊⼒
に加えられない
n ここではin_channels =
num_hiddensにする
n layerをリストにするとき
はnn.ModuleListを使う
n 理由は普通のリストを使
うと学習可能なパラメー
タを更新できないから
実装: VQ部分
n ⼊⼒shapeをBCHWからBHWCにして,D
次元のベクトルが並ぶようにflatにする
n embeddingとエンコードベクトルの距離計
算はループを使わない賢いテクニックを
使っている
n それぞれのエンコードされたベクトルと埋
め込みベクトルの⼆乗距離を分解している
実装: VQ部分
n encoding_indices: (B*W*H, 1)
n 距離で⼀番近い部分をとる
n encodings: (B*W*H, K)
n 0の⾏列をつくる
n encodings.scatter_(1, encoding_indices, 1)
n 1番⽬のaxisでインデックス番号の0を1に変換する
(one hotになる)
n quantized: (B, H, W, D)
n エンコードベクトルを埋め込みベクトルにする
quantized
encodings
encoding
indices
実装: VQ部分
n ロスの⼀部分を計算する
n inputs: z_e(x), quantized: e
n detachでstop gradientできる
n その次のコードは,⼊⼒に勾配を伝えるため
n 埋め込みベクトルに置き換えると勾配が⼊⼒に
伝わらなくなるから
実装: VQ部分
n 特に,ロス関数の⼆番⽬の項はEMAを利⽤する
と収束が早い
n ⼀般には,各埋め込みベクトル𝑒"から最も近い
エンコードされたベクトルを𝑧"とし,その数を
𝑛"とすると, 𝑒"は𝑧"の平均をとればよい
n しかし,ミニバッチによる計算を⾏っているた
め,移動平均を利⽤したほうが良い
n ガンマはハイパラで,0.99くらい
実装: 最終的なモデルの訓練
n pre_vq_convでは,VQ部分の⼊⼒チャ
ネルをD次元にするために1x1convをか
ましている
n 訓練は先程説明したVQのlossに再構成
誤差を加えたものを最⼩化する
提案⼿法: VQVAE2
n VQVAEは256x256のような⾼解像度画像のサンプリングは⼗分にきれいにできない
n VQVAE2はVQVAEを階層化することによって,これを解決する
提案⼿法: VQVAE2の訓練アルゴリズム
n ロス関数はVQVAEと同じで,エンコーダが⼀つ増
えるだけ
提案⼿法: VQVAE2のサンプリングアルゴリズム
n まず,topの潜在変数とbottomの潜在変数
を利⽤し,2つのpixelCNN(𝑝&'(, 𝑝*'&&'+)
を学習する
n 次に,学習したpixelCNNを使って潜在変
数をサンプリングし,デコーダに通す
実装: VQVAE2
n 構造は少し複雑になるが,Encoderが⼀つ
増えるだけで,VQVAEと同じように画像
の通りに実装するだけで良い
n 元論⽂では,bottom levelの潜在変数のサ
イズが元画像の1/4で,top levelの潜在変数
のサイズは元画像の1/8
n ⼊⼒が2つあるときはサイズをあわせて
channel⽅向でconcatする

実装レベルで学ぶVQVAE