1
DEEP LEARNING JP
[DL Papers]
http://deeplearning.jp/
“Efficiently Modeling Long Sequences with Structured State Spaces”
Naoki Nonaka
2021/12/3
書誌情報
2021/12/3 2
• 会議:ICLR2022 投稿(評価: 8, 8, 8)
(本スライドはArxivに投稿されている論文に基づいて作成)
• 著者:
概要
2021/12/3 3
 長距離の依存関係を持つ系列データの問題に取り組んだ研究
 SSM(状態空間モデル)x Deep Learningのアプローチを提案
 ベンチマークにて既存手法を大幅に上回る性能を実現
背景
2021/12/3 4
長距離の依存関係(Long-range dependencies: LRD)
…
依存関係
T
 実世界のデータでは,数万ステップでの推論が必要
(具体例としては,音声や言語情報など)
 LRDに取り組んだ深層学習による従来の手法としては,
RNN, CNNやTransformerとその改良手法が提案されてきた
背景
2021/12/3 5
O X
RNN
ステップごとの
計算量/ストレージが一定
学習に時間がかかる
最適化が難しい(Vanishing gradient)
CNN 並列可能で高速に学習できる
逐次学習ではないので
推論時のコストが高い/扱える長さに制限
LRDに取り組んだ従来手法の利点と欠点
(Transformer系の手法もCNNとほぼ同じ)
背景
2021/12/3 6
理想的な時系列モデル
 各時刻における状態を保持し,推論が可能(recurrence)
 並列計算による学習が可能(convolutional)
 任意の時間軸適応(微分方程式の性質)
状態空間モデル(State Space Model; SSM)
背景
2021/12/3 7
状態空間モデル
 入力,出力,状態の3つの変数からなる数学的モデル
 多くの数理モデルの基礎となっているモデル
状態空間モデル x 深層学習の手法は存在しなかった※
※ 厳密には同一著者の先行研究[1]が不完全ながら取り組んでいる
図は[1]をもとに改変
提案手法: S4
2021/12/3 8
S4: Structured State Space sequence model
→ 状態空間モデル x 深層学習の手法
1. SSMのRecurrent表現とConvolution表現の導出
2. HiPPO行列による連続時間記憶の問題の解決 ※
3. SSM convolutionカーネル(後述)の計算の効率化
※ 同一著者の先行研究[1]における工夫と同じ
S4の導出過程
S4: Recurrent表現とConvolution表現の導出
2021/12/3 9
S4 (SSM): 再帰的な計算と並列学習が可能
連続時間SSM
離散時間SSM
畳み込み演算での表現
RNN様の再帰的な計算が可能に
CNN様の並列計算が可能に
S4: Recurrent表現とConvolution表現の導出
2021/12/3 10
 間隔Δで離散化
 Bilinear法を使用
 離散化により,離散的な入力データを扱えるようになる
 RNNと同じく再帰的な処理が可能になる
離散時間SSM
S4: Recurrent表現とConvolution表現の導出
2021/12/3 11
展開
畳み込み演算での表現
SSM convolution kernel (K) を定義
SSMの畳み込み演算
提案手法: S4
2021/12/3 13
S4: Structured State Space sequence model
→ 状態空間モデル x 深層学習の手法
1. SSMのRecurrent表現とConvolution表現の導出
2. HiPPO行列による連続時間記憶の問題の解決 ※
3. SSM convolutionカーネル(後述)の計算の効率化
※ 同一著者の先行研究[1]における工夫と同じ
S4の導出過程
S4: HiPPO行列による連続時間記憶
2021/12/3 14
 直交多項式の重み付き和によって過去の系列を表現
 RNNに組み込むと記憶性能が向上する
HiPPO: High-order Polynomial Projection Operators
図は[2]より
提案手法: S4
2021/12/3 19
S4: Structured State Space sequence model
→ 状態空間モデル x 深層学習の手法
1. SSMのRecurrent表現とConvolution表現の導出
2. HiPPO行列による連続時間記憶の問題の解決 ※
3. SSM convolutionカーネル(後述)の計算の効率化
※ 同一著者の先行研究[1]における工夫と同じ
S4の導出過程
S4: SSM convolution kernelの計算
2021/12/3 20
SSMの学習の並列化
Aの冪乗計算が必要
連続時間記憶の改善
AはHiPPO行列である必要
HiPPO行列の冪乗計算が必要
S4: SSM convolution kernelの計算
2021/12/3 21
K の計算: 行列Aの冪乗計算を含むため工夫が必要
 Aを,対角行列Λ + 低ランク行列 p, q (rank=1)
 3つの計算工夫を導入
S4: SSM convolution kernelの計算
2021/12/3 22
1. FFTによる冪乗計算の回避 (詳細はAppendix C3)
𝑧におけるSSM母関数を定義
数列𝑎𝑛に対する母関数
𝑓 𝑥 =
𝑘=0
∞
𝑎𝑘𝑥𝑘
𝑧を1の冪根とすると,
1の冪根
 𝜍 = exp
2𝜋𝑖
𝑛
 ある𝑛に対して
𝑧𝑛 = 1を満たす𝑧
→ 離散フーリエ変換と一致
SSM母関数で冪乗計算を逆行列計算化 + 逆FFTで K を得る
S4: SSM convolution kernelの計算
2021/12/3 23
2. 対角行列 + 低ランク行列の逆行列計算
Woodbury恒等式を利用
SSM母関数における逆行列計算を効率化
3. Cauchyカーネルによる計算
Aが対角行列のときSSM母関数の計算 = Cauchyカーネルの計算
Cauchyカーネルの計算アルゴリズムを利用
S4 layer
2021/12/3 24
実装上は,系列を受け取り,系列を出力する層となる
https://github.com/HazyResearch/state-spaces/blob/main/example.py
LayerNorm
Input
S4
Dropout
…
…
__init__ 内 forward 内
実験
2021/12/3 25
 計算効率
 長距離の依存関係の学習
 汎用系列モデルとしての性能
実験: 計算効率
2021/12/3 26
 LSSL(状態空間モデル系の先行研究)よりも高速・高メモリ効率
 (Efficientな)Transformer系と同程度に高速・省メモリ
実験: 長距離の依存関係の学習
2021/12/3 27
 Long Range Arena (LRA)
 (主にTransformer系の手法を念頭にした)
長距離の依存関係のモデリング性能を評価するためのデータセット []
 6つのタスクで構成される
 Raw speech classification
 Speech Commandデータセット(35クラス,100,503件のサンプル)
 話し言葉の音声データの中からキーワードを検出するタスク
実験: 長距離の依存関係の学習 (LRA: 1/4)
2021/12/3 28
1. <LISTOPS> Long ListOps
複数の演算子(MAX, MEAN, MEDIAN, SUM_MOD)の階層構造で
表現された系列から出力となる数字を当てるタスク
2. <TEXT> Byte-level Text classification
 IMDbレビューをもとに作成されたデータセット
 byte/character-levelで分類
実験: 長距離の依存関係の学習 (LRA: 2/4)
2021/12/3 29
3. <RETRIEVAL> Byte-level Document Retrieval
 長い文章を短い表現に圧縮し,文章の類似度を評価するタスク
 元データはIMDbのレビュー
 系列長は4k(長いものはtruncate, 短いものはpadding)
4. <IMAGE> Image Classification on sequence of pixels
 Sequential MNISTのCIFAR-10版
 系列長3072 (= 32 x 32 x 3) のサンプルを10クラスに分類
実験: 長距離の依存関係の学習 (LRA: 3/4)
2021/12/3 30
3. <PATHFINDER> PathFinder
画像中の2点が破線でつながっているか判定
入力:32 x 32の画像の系列(=784)
出力:二値(2点がつながっているか)
4. <PATH-X> PathFinder-X
PathFinderタスクを128 x 128に拡大した画像で実施
実験: 長距離の依存関係の学習 (LRA: 1/4)
2021/12/3 31
 6つのタスク全てで既存手法を大幅に上回る
 PathFinder-Xを解けた唯一のモデル
実験: 長距離の依存関係の学習
2021/12/3 32
 Long Range Arena (LRA)
 (主にTransformer系の手法を念頭にした)
長距離の依存関係のモデリング性能を評価するためのデータセット []
 6つのタスクで構成される
 Raw speech classification
 Speech Commandデータセット(35クラス,100,503件のサンプル)
 話し言葉の音声データの中からキーワードを検出するタスク
実験: 長距離の依存関係の学習(Speech: 1/1)
2021/12/3 33
 MFCCによる前処理あり:先行研究と同程度の性能
 Rawデータでの分類:WaveGANを上回る性能
実験: 汎用系列モデルとしての性能
2021/12/3 34
 大規模な生成モデルの学習
 CIFAR-10における密度推定
 WikiText-103における言語モデリング
 自己回帰による推論
 CIFAR-10およびWikiText-103での生成速度を比較
実験: 汎用系列モデルとしての性能
2021/12/3 35
 先行研究と同程度の性能を達成
 自己回帰による推論の速度は60倍以上高速化
大規模な生成モデルの学習/自己回帰による推論
実験: 汎用系列モデルとしての性能
2021/12/3 36
不規則にサンプリングされたデータの扱い
 Test時のみ周波数を0.5倍にして評価(右列)
 S4では,追加学習なしでも周波数の
変化に対して頑健になっている
結論・まとめ
2021/12/3 37
 状態空間モデルにDNNを取り込んだS4モデルを提案
 LRAにて既存手法を大幅に上回る性能を実現
 汎用系列モデルとしても優れた性能を示す
Reference
2021/12/3 38
1. Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers
2. HiPPO: Recurrent Memory with Optimal Polynomial Projections
Appendix
2021/12/3 39

【DL輪読会】Efficiently Modeling Long Sequences with Structured State Spaces

Editor's Notes

  • #21  K を直接計算せず,母関数の計算 + 逆フーリエ変換に置換 → 行列の冪乗計算を逆行列の計算に変える 逆行列の計算をWoodburyの恒等式により行う Cauchyカーネルの計算に落とし込む
  • #23  K を直接計算せず,母関数の計算 + 逆フーリエ変換に置換 → 行列の冪乗計算を逆行列の計算に変える 逆行列の計算をWoodburyの恒等式により行う Cauchyカーネルの計算に落とし込む
  • #24  K を直接計算せず,母関数の計算 + 逆フーリエ変換に置換 → 行列の冪乗計算を逆行列の計算に変える 逆行列の計算をWoodburyの恒等式により行う Cauchyカーネルの計算に落とし込む