Successfully reported this slideshow.

HiPPO/S4解説

0

Share

Loading in …3
×
1 of 96
1 of 96

HiPPO/S4解説

0

Share

Description

Morpho Tech Bolgより本資料に関する記事をご覧ください。
・「HiPPO/S4解説」:https://techblog.morphoinc.com/entry/2022/05/24/102648
・執筆者:CTO室リサーチャー 角田

・Morpho Tech Blog: https://techblog.morphoinc.com/
・Morpho, Inc.: https://www.morphoinc.com/

Transcript

  1. 1. HiPPO/S4解説 2022/04/22 LSTMやTransformerを超える 時系列モデリングの新手法 株式会社モルフォ 角田良太朗
  2. 2. Copyright © 2022 Morpho, Inc. All Rights Reserved 1 Overview Warning 本日の内容はゴリゴリの理論です。 中身は非常に美しいのですが、なにせ1時間しかないこともあり 結構な脱落者が発生する危険が高いです。 それを踏まえ、通常の流れには逆らい、まず結果から説明します。
  3. 3. Copyright © 2022 Morpho, Inc. All Rights Reserved 2 Overview Long-Range Arena[1]というタスクが存在する。 • Efficient Transformer系の長距離依存性を統一的に評価するためのベン チマーク。 • トークン数1000~16000の様々なテストデータを用意。 • タスクは6種類: • LONG LISTOPS • BYTE-LEVEL TEXT CLASSIFICATION • BYTE-LEVEL DOCUMENT RETRIEVAL • IMAGE CLASSIFICATION ON SEQUENCES OF PIXELS • PATHFINDER • PATHFINDER-X(←次ページで詳述)
  4. 4. Copyright © 2022 Morpho, Inc. All Rights Reserved 3 Overview PATHFINDER-X positive negative 128x128の画像中の2点が点線でつながっているか二値判定。 ただし、 画像はflattenして入力、 2D-Convの使用禁止。 ([6]より引用) ([6]より引用)
  5. 5. Copyright © 2022 Morpho, Inc. All Rights Reserved 4 Overview このタスクで2021年11月に大幅なSOTA更新あり。 ([6]より引用)
  6. 6. Copyright © 2022 Morpho, Inc. All Rights Reserved 5 Overview このタスクで2021年11月に大幅なSOTA更新あり。 S4以外のすべてのモデルは、推論に失敗していた(乱択と同程度) ([6]より引用)
  7. 7. Copyright © 2022 Morpho, Inc. All Rights Reserved 6 Overview 本スライドの目標は、この驚異的なモデルS4とは何者なのかを解明すること。 手法をざっと述べると。。。 ([6]より引用)
  8. 8. Copyright © 2022 Morpho, Inc. All Rights Reserved 7 Overview 1. 時系列解析を状態空間モデルとして次のように定式化 𝑥’(𝑡) = 𝐴𝑥(𝑡) + 𝐵𝑢(𝑡) 𝑦(𝑡) = 𝐶𝑥(𝑡) + 𝐷𝑢(𝑡) (𝑢 𝑡 ∈ ℝ𝐿:input, 𝑥 𝑡 ∈ ℝ𝑁∗𝐿: hidden state, 𝑦 𝑡 ∈ ℝ𝐿: output) A, B, C, D はlearned parameters AはHiPPO matrixとして初期化し、正規行列+low-rank matrixの形のみを 取るよう制限する。
  9. 9. Copyright © 2022 Morpho, Inc. All Rights Reserved 8 Overview 2. この式を離散化&展開することで 𝑦𝑘 = 𝐶𝐴𝑘𝐵𝑢0 + 𝐶𝐴𝑘−1𝐵𝑢1 + ⋯ + 𝐶𝐴𝐵𝑢𝑘−1 + 𝐶𝐵𝑢𝑘 𝑦 = 𝐾 ∗ 𝑢 (𝐾 ≔ 𝐶𝐵 + 𝐶𝐴𝐵 + ⋯ + 𝐶𝐴𝐿−1𝐵 ∈ ℝ𝐿) と1D-Convの形にかける。 𝐾が高速計算できれば学習はRNNより高速。 (𝐿回のiterationをせず一発で計算できるので)
  10. 10. Copyright © 2022 Morpho, Inc. All Rights Reserved 9 Overview 3.Kを直接計算せず、これのスペクトラム 𝐹 𝐾 ≔ ෍ 𝑗=0 𝐿−1 𝐾𝑗𝜁𝑗 を計算してiFFTで𝐾𝑗を一括導出したい。 𝐹(𝐾)を一般に𝐾の𝑧変換として求めることを考える。 これはAを前述の形に制限したことから、 Woodbury-Identityを用いて Cauchy Kernel4つの和として記述でき、特に ෨ 𝑂(𝑁 + 𝐿)で導出可能。 (おわり)
  11. 11. Copyright © 2022 Morpho, Inc. All Rights Reserved 10 Overview 何を言ってるのか全く分からないが、状態空間モデルの方程式を解くのに • 適切な係数空間の設計(HiPPO行列) • convolutionに書き直して高速な行列計算 をしたことで、超長距離依存性を保持できたことがポイントみたい。 実際次で見るように、これらの施策は精度にcriticalに効いている。
  12. 12. Copyright © 2022 Morpho, Inc. All Rights Reserved 11 Overview • 係数行列の初期化の影響(sequential CIFAR10で実験) ([6]より引用)
  13. 13. Copyright © 2022 Morpho, Inc. All Rights Reserved 12 Overview • 高速な行列計算アルゴリズムの影響 ([6]より引用)
  14. 14. Copyright © 2022 Morpho, Inc. All Rights Reserved 13 Overview 高度で緻密な理論設計がここまで見事に精度に反映されるような deep learning modelは見たことがない! これは解読する価値がありそうだ。。。!
  15. 15. Copyright © 2022 Morpho, Inc. All Rights Reserved 14 Overview しかし現段階ではいろいろと謎が多すぎる。 • 唐突に出てきた線形方程式は一体? • HiPPOって何? • 正規行列+low-rankでパラメータを書く動機は? • Woodbury-Identityどこで使うんや? • Cauchy kernelがなんで絡むん? • ・・・・・・
  16. 16. Copyright © 2022 Morpho, Inc. All Rights Reserved 15 Overview 色々調べると以下の3本の論文が1セットになっていた。 HiPPO: Recurrent Memory with Optimal Polynomial Projections Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, Christopher Ré (NeurIPS 2020 Spotlight) Combining Recurrent, Convolutional, and Continuous-time Models with the Linear State Space Layer Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Christopher Ré (NeurIPS 2021) Efficiently Modeling Long Sequences with Structured State Spaces Albert Gu, Karan Goel, Christopher Ré (ICLR 2022 Oral)
  17. 17. Copyright © 2022 Morpho, Inc. All Rights Reserved 16 Overview 色々調べると以下の3本の論文が1セットになっていた。 HiPPO: Recurrent Memory with Optimal Polynomial Projections Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, Christopher Ré (NeurIPS 2020 Spotlight) Combining Recurrent, Convolutional, and Continuous-time Models with the Linear State Space Layer Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Christopher Ré (NeurIPS 2021) Efficiently Modeling Long Sequences with Structured State Spaces Albert Gu, Karan Goel, Christopher Ré (ICLR 2022 Oral) First author全部同じ人や。。。 Albert Guさん強すぎ。。。
  18. 18. Copyright © 2022 Morpho, Inc. All Rights Reserved 17 Overview 色々調べると以下の3本の論文が1セットになっていた。 HiPPO: Recurrent Memory with Optimal Polynomial Projections Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, Christopher Ré (NeurIPS 2020 Spotlight) Combining Recurrent, Convolutional, and Continuous-time Models with the Linear State Space Layer Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Christopher Ré (NeurIPS 2021) Efficiently Modeling Long Sequences with Structured State Spaces Albert Gu, Karan Goel, Christopher Ré (ICLR 2022 Oral) First author全部同じ人や。。。 Albert Guさん強すぎ。。。 今日はこれらを解説していく!! S4を分かるには これ全部理解するしかない
  19. 19. Copyright © 2022 Morpho, Inc. All Rights Reserved 18 Overview Warning いよいよ本番です。 頑張って付いてきてください。 分からなくなったらすぐ声を上げてください。
  20. 20. Copyright © 2022 Morpho, Inc. All Rights Reserved 19 HiPPO HiPPO: Recurrent Memory with Optimal Polynomial Projections (NeurIPS 2020 Spotlight) 正直この論文が一番の肝
  21. 21. Copyright © 2022 Morpho, Inc. All Rights Reserved 20 HiPPO 概要 長時間の時系列データを扱う際はRNNやLSTMがよく用いられるが、 • 数万ステップにもなると記憶が抜けてしまう。 • シーケンス長や時間スケールへの暗黙の依存により、テスト時に汎化しない。 • 理論的な解釈が何となくしか与えられていない。 という問題がある。
  22. 22. Copyright © 2022 Morpho, Inc. All Rights Reserved 21 HiPPO 概要 そこで本論文では • 記憶に関する理論的な定式化を与え、既存手法たちをその枠組みで再解釈。 • シーケンス長や時間スケールに依存しない新手法を提案。 • 新手法の優位性を、提案した枠組みを用いて厳密に証明。 している。
  23. 23. Copyright © 2022 Morpho, Inc. All Rights Reserved 22 HiPPO 手法 そもそも「記憶を保持する」という言葉が既に曖昧。 そこでこれを次のように言い換えるところからスタートする。 「記憶」 = 「時間依存する測度」に基づく「入力信号の多項式近似」 1 2
  24. 24. Copyright © 2022 Morpho, Inc. All Rights Reserved 23 HiPPO ①入力信号の多項式近似 ℝ上の入力信号𝑓: ℝ → ℝを考えよう。 目標は 𝑓 𝑥 𝑥 ≤ 𝑡 までが与えられるので、これから𝑓(𝑥)を推定することである。 t 𝑥 𝑓(𝑥) 0
  25. 25. Copyright © 2022 Morpho, Inc. All Rights Reserved 24 HiPPO ①入力信号の多項式近似 {𝑓(𝑥)|𝑥 ≤ 𝑡}の値をすべて使って推定するのが最善だが、我々はこれをモデルに 学習させたいので、すべての値をモデルに記憶させるのはメモリ的に厳しい。 そこで{𝑓(𝑥)|𝑥 ≤ 𝑡}を何らかの低次元表現に保存することを考える。 𝑡 𝑥 𝑓(𝑥) 0
  26. 26. Copyright © 2022 Morpho, Inc. All Rights Reserved 25 HiPPO ①入力信号の多項式近似 ここでは「直交多項式系」を低次元表現として採用する。 Def. 測度𝜇(𝑥)に対する直交多項式系とは多項式の集合 𝑃𝑛 𝑛=0,1,2,… であって deg 𝑃𝑛 = 𝑛, ⟨𝑃𝑛, 𝑃𝑚⟩𝜇 ≔ න 𝑃𝑛 𝑥 𝑃𝑚 𝑥 𝑑𝜇 = 𝑎𝑛,𝑚𝛿𝑛,𝑚 (∃𝑎𝑛,𝑚 ∈ ℝ) を満たすものを言う。 Rem. ℝ上の測度が与えられれば、直交多項式はスケール𝑎𝑛,𝑚を除き一意的。 これは{1, 𝑥, 𝑥2, … }をグラムシュミット直交化すれば示せる。
  27. 27. Copyright © 2022 Morpho, Inc. All Rights Reserved 26 HiPPO ①入力信号の多項式近似 Ex. ルジャンドル多項式 𝑃𝑛 𝑥 ≔ 1 2𝑛𝑛! 𝑑𝑛 𝑑𝑥𝑛 [ 𝑥2 − 1 𝑛] は測度𝜇 𝑥 ≔ 1 −1,1 𝑥 に対する直交多項式。 直交性はググれば証明出てくる。 𝑃𝑛, 𝑃𝑚 = 2 2𝑛+1 𝛿𝑛,𝑚 なおルジャンドル多項式特有の性質として次がある(後で使う) • 𝑃𝑛 1 = 1, 𝑃𝑛 −1 = −1 𝑛 • 𝑃𝑛 ′ = 2𝑛 − 1 𝑃𝑛−1 + 2𝑛 − 3 𝑃𝑛−2 + ⋯ • 𝑥 + 1 𝑃𝑛 ′ 𝑥 = 𝑛𝑃𝑛 + 2𝑛 − 1 𝑃𝑛−1 + 2𝑛 − 3 𝑃𝑛−2 + ⋯
  28. 28. Copyright © 2022 Morpho, Inc. All Rights Reserved 27 HiPPO ①入力信号の多項式近似 この時𝑓[𝑥≤𝑡}を直交多項式 𝑃𝑛 𝑛=0,1,…,𝑁−1で近似することを考えよう。 直交多項式の次数は適当な𝑁未満で打ち切っていることに注意。 Fact ([2, Theorem3.10, Theorem3.5]) 1. 測度𝜇から定まる直交多項式系 𝑃𝑛 𝑛=0,1,…を固定した時、 任意の関数𝑓 ∈ 𝐿2(ℝ; 𝜇)は以下の級数展開を持つ。 𝑓 = ෍ 𝑛=0,1,… 𝑐𝑛𝑃𝑛 , 𝑐𝑛 ≔ 𝑓, 𝑃𝑛 / 𝑃𝑛, 𝑃𝑛 2. 上記級数を𝑛 = 𝑁 − 1で打ち切ったものを𝑓(𝑛) としたとき、 𝑓(𝑛) = 𝑎𝑟𝑔𝑚𝑖𝑛𝑔∈𝑆𝑝𝑎𝑛⟨𝑃0,…,𝑃𝑁−1⟩ 𝑓 − 𝑔 𝜇
  29. 29. Copyright © 2022 Morpho, Inc. All Rights Reserved 28 HiPPO ①入力信号の多項式近似 気持ちとしては、「𝑓をN次未満直交多項式の張る空間に射影」している。 これにより、話を元に戻すと 𝑓{𝑥≤𝑡} ≈ 𝑐0𝑃𝑜 + 𝑐1𝑃1 + ⋯ + 𝑐𝑁−1𝑃𝑁−1 として過去の信号履歴を(𝑐0, 𝑐1, … , 𝑐𝑁−1)の𝑁変数に圧縮することができた。 さらに上式を使って未来の信号を予測することも可能! 𝑓{𝑥≤𝑡} 𝑆𝑝𝑎𝑛⟨𝑃0, … , 𝑃𝑁−1⟩
  30. 30. Copyright © 2022 Morpho, Inc. All Rights Reserved 29 HiPPO ②時間依存する測度 ではどんな測度(直交多項式)を選ぶのが最適だろうか? • 入力信号𝑓は時間が進むごとに履歴がどんどん積み重なるので、 測度𝜇も時間発展させた方がよいだろう。 そこで以後、時刻𝑡における測度を𝜇 𝑡 と記述する。 注意: 特に以降𝜇 𝑡 のsupportは(−∞, 𝑡]に含まれるものとする。
  31. 31. Copyright © 2022 Morpho, Inc. All Rights Reserved 30 HiPPO ここまでのまとめ • 「記憶」=「入力信号𝑓の過去履歴を𝑁次元直交多項式系{𝑃𝑛}に射影」 • 直交多項式系 𝑃𝑛 は測度𝜇 𝑡 に応じて時間発展させる。 Def. 入力信号𝑓の履歴を直交多項式系に射影して、その係数を取得する操作を ℎ𝑖𝑝𝑝𝑜 𝑓{𝑥≤𝑡} ≔ 𝑐0, 𝑐1, … , 𝑐𝑁−1 𝑤ℎ𝑒𝑟𝑒 𝑓{𝑥≤𝑡} ≈ 𝑐0𝑃𝑜 + 𝑐1𝑃1 + ⋯ + 𝑐𝑁−1𝑃𝑁−1 と書き、HiPPO operatorと呼ぶ。(HiPPO=high-order Polynomial Projection Operator)
  32. 32. Copyright © 2022 Morpho, Inc. All Rights Reserved 31 HiPPO 「記憶」=(𝑐0, 𝑐1, … , 𝑐𝑁−1)なことは分かったので、次は 「記憶のアップデート」= (𝑐0, 𝑐1, … , 𝑐𝑁−1)の時間発展 がどうなっているか導出したい。 実は驚くべき結論が成り立つ。 Theorem ([3, Appendix C]) 古典的な直交関数系に対して、𝑐(𝑡)の時間発展はlinear ODEで記述できる: 𝑐′ 𝑡 = 𝐴 𝑡 𝑐 𝑡 + 𝐵 𝑡 𝑓 𝑡 , (∃𝐴 𝑡 ∈ ℝ𝑁∗𝑁, ∃𝐵 𝑡 ∈ ℝ𝑁∗1)
  33. 33. Copyright © 2022 Morpho, Inc. All Rights Reserved 32 HiPPO 冒頭で線形方程式が出てきた理由はまさにこれ。 以下これを証明し、具体例を与える。 Notations • 𝑓, 𝜇(𝑡), 𝑃𝑛 (𝑡) :入力信号、時刻tでの測度、付随する直交多項式 • 𝑑𝜇(𝑡) = 𝜔(𝑡) 𝑥 𝑑𝑥、また𝜇(𝑡)は確率測度であると仮定(i.e.‫׬‬ 𝑑𝜇(𝑡) = 1) • 𝑝𝑛 (𝑡) 𝑥 = 𝑃𝑛 𝑡 (𝑥)/⟨𝑃𝑛 𝑡 , 𝑃𝑛 𝑡 ⟩(正規化)
  34. 34. Copyright © 2022 Morpho, Inc. All Rights Reserved 33 HiPPO (証明) まず係数𝑐𝑛(𝑡)の構成を思い出せば 𝑐𝑛 𝑡 = 𝑓≤𝑡, 𝑃𝑛 (𝑡) / 𝑃𝑛 (𝑡) , 𝑃𝑛 (𝑡) = න 𝑓 𝑥 ∗ 𝑝𝑛 𝑡 𝑥 ∗ 𝜔 𝑡 (𝑥)𝑑𝑥 両辺を𝑡で微分すると(積分と微分の可換性は認める) 𝑐𝑛 ′ (𝑡) = න 𝑓 ∗ 𝜕 𝜕𝑡 𝑝𝑛 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 + න 𝑓 ∗ 𝑝𝑛 𝑡 ∗ 𝜕 𝜕𝑡 𝜔 𝑡 𝑑𝑥
  35. 35. Copyright © 2022 Morpho, Inc. All Rights Reserved 34 HiPPO (証明) まず係数𝑐𝑛(𝑡)の構成を思い出せば 𝑐𝑛 𝑡 = 𝑓≤𝑡, 𝑃𝑛 (𝑡) / 𝑃𝑛 (𝑡) , 𝑃𝑛 (𝑡) = න 𝑓 𝑥 ∗ 𝑝𝑛 𝑡 𝑥 ∗ 𝜔 𝑡 (𝑥)𝑑𝑥 両辺を𝑡で微分すると(積分と微分の可換性は認める) 𝑐𝑛 ′ (𝑡) = න 𝑓 ∗ 𝜕 𝜕𝑡 𝑝𝑛 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 + න 𝑓 ∗ 𝑝𝑛 𝑡 ∗ 𝜕 𝜕𝑡 𝜔 𝑡 𝑑𝑥 𝜕 𝜕𝑡 𝑝𝑛 𝑡 は𝑥についての𝑛次多項式 であるから𝑝0 (𝑡) , … , 𝑝𝑛 (𝑡) の線形和。
  36. 36. Copyright © 2022 Morpho, Inc. All Rights Reserved 35 HiPPO (証明) まず係数𝑐𝑛(𝑡)の構成を思い出せば 𝑐𝑛 𝑡 = 𝑓≤𝑡, 𝑃𝑛 (𝑡) / 𝑃𝑛 (𝑡) , 𝑃𝑛 (𝑡) = න 𝑓 𝑥 ∗ 𝑝𝑛 𝑡 𝑥 ∗ 𝜔 𝑡 (𝑥)𝑑𝑥 両辺を𝑡で微分すると(積分と微分の可換性は認める) 𝑐𝑛 ′ (𝑡) = න 𝑓 ∗ 𝜕 𝜕𝑡 𝑝𝑛 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 + න 𝑓 ∗ 𝑝𝑛 𝑡 ∗ 𝜕 𝜕𝑡 𝜔 𝑡 𝑑𝑥 𝜕 𝜕𝑡 𝑝𝑛 𝑡 は𝑥についての𝑛次多項式 であるから𝑝0 (𝑡) , … , 𝑝𝑛 (𝑡) の線形和。 𝜕 𝜕𝑡 𝜔(𝑡)は古典的な直交関数系では 𝜔(𝑡) とディラック𝛿𝑡の線形和。
  37. 37. Copyright © 2022 Morpho, Inc. All Rights Reserved 36 HiPPO (証明) まず係数𝑐𝑛(𝑡)の構成を思い出せば 𝑐𝑛 𝑡 = 𝑓≤𝑡, 𝑃𝑛 (𝑡) / 𝑃𝑛 (𝑡) , 𝑃𝑛 (𝑡) = න 𝑓 𝑥 ∗ 𝑝𝑛 𝑡 𝑥 ∗ 𝜔 𝑡 (𝑥)𝑑𝑥 両辺を𝑡で微分すると(積分と微分の可換性は認める) 𝑐𝑛 ′ (𝑡) = න 𝑓 ∗ 𝜕 𝜕𝑡 𝑝𝑛 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 + න 𝑓 ∗ 𝑝𝑛 𝑡 ∗ 𝜕 𝜕𝑡 𝜔 𝑡 𝑑𝑥 𝜕 𝜕𝑡 𝜔(𝑡)は古典的な直交関数系では 𝜔(𝑡) とディラック𝛿𝑡の線形和。 これより第一項は𝑐0, … , 𝑐𝑛の線形和、第二項は𝑐𝑛と𝑓(𝑡)の線形和。(証明終) 𝜕 𝜕𝑡 𝑝𝑛 𝑡 は𝑥についての𝑛次多項式 であるから𝑝0 (𝑡) , … , 𝑝𝑛 (𝑡) の線形和。
  38. 38. Copyright © 2022 Morpho, Inc. All Rights Reserved 37 HiPPO 実際にルジャンドル関数系を用いて実証してみる。 ただしルジャンドル関数系は[−1,1]上の関数系なので、 存在域を𝑡依存になるようスケールしてから適用する。 パターン1: [𝑡 − 𝜃, 𝑡]上に定義(𝜃 ≥ 0は何時刻前までを見るかを表すハイパラ) パターン2: [0, 𝑡]上に定義(過去の履歴をすべて見る) それぞれの場合で𝑐𝑛 ′ (𝑡)がどう書けるか見てみよう。 (図2つは[3]より引用)
  39. 39. Copyright © 2022 Morpho, Inc. All Rights Reserved 38 HiPPO パターン1: [𝑡 − 𝜃, 𝑡]上に定義 このときの正規直交関数系は、ルジャンドル関数系 𝑃𝑛 𝑥 を用いて 𝑝𝑛 𝑡 𝑥 ≔ 2𝑛 + 1 1 2𝑃𝑛 2 𝑥 − 𝑡 𝜃 + 1 𝜕 𝜕𝑡 𝑝𝑛 𝑡 = − 2𝑛 + 1 1 2 2 𝜃 2𝑛 − 1 1/2𝑝𝑛−1 (𝑡) + 2𝑛 − 5 1/2𝑝𝑛−3 (𝑡) + ⋯ またこのとき 𝜔 𝑡 𝑥 = 1 𝜃 1 𝑡−𝜃,𝑡 = 𝐻𝑒𝑣𝑖𝑠𝑖𝑑𝑒 𝑥 − 𝑡 − 𝜃 − 𝐻𝑒𝑣𝑖𝑠𝑖𝑑𝑒 𝑥 − 𝑡 𝜕 𝜕𝑡 𝜔 𝑡 = −𝛿 𝑥 − 𝑡 − 𝜃 + 𝛿 𝑥 − 𝑡 = 𝛿𝑡 − 𝛿𝑡−𝜃 (図は[3]より引用)
  40. 40. Copyright © 2022 Morpho, Inc. All Rights Reserved 39 HiPPO パターン1: [𝑡 − 𝜃, 𝑡]上に定義 これを先ほどの式に代入すると 第一項 = න 𝑓 ∗ 𝜕 𝜕𝑡 𝑝𝑛 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 = − 2𝑛 + 1 1 2 2 𝜃 2𝑛 − 1 1 2𝑐𝑛−1 + 2𝑛 − 5 1 2𝑐𝑛−3 + ⋯ 第二項 = න 𝑓 ∗ 𝑝𝑛 𝑡 ∗ 𝜕 𝜕𝑡 𝜔 𝑡 𝑑𝑥 = 𝑓 𝑡 𝑝𝑛 𝑡 𝑡 − 𝑓 𝑡 − 𝜃 𝑝𝑛 𝑡 𝑡 − 𝜃 辺々加えてこねこねすれば、 𝑐′ 𝑡 = − 1 𝜃 𝐴𝑐 𝑡 + 1 𝜃 𝐵𝑓 𝑡 𝐴𝑛𝑘 = 2𝑛 + 1 1 2 2𝑘 + 1 1 2 ൝ 1 𝑖𝑓 𝑘 ≤ 𝑛 −1 𝑛−𝑘 𝑖𝑓 𝑘 ≥ 𝑛 , 𝐵𝑛 = 2𝑛 + 1 1 2 (図は[3]より引用)
  41. 41. Copyright © 2022 Morpho, Inc. All Rights Reserved 40 HiPPO パターン1: [𝑡 − 𝜃, 𝑡]上に定義 Def. 測度 1 𝜃 1 𝑡−𝜃,𝑡 から導出されるHiPPOの時間発展式 𝑐′ 𝑡 = − 1 𝜃 𝐴𝑐 𝑡 + 1 𝜃 𝐵𝑓 𝑡 ただし 𝐴𝑛𝑘 = 2𝑛 + 1 1 2 2𝑘 + 1 1 2 ൝ 1 𝑖𝑓 𝑘 ≤ 𝑛 −1 𝑛−𝑘 𝑖𝑓 𝑘 ≥ 𝑛 , 𝐵𝑛 = 2𝑛 + 1 1 2 をHiPPO-LegTと呼ぶ。(translated Legendre) 実はこのODEは少し式変形すると[4]の論文で提案された式と一致する。 しかし[4]の論文ではPadé approximationという別手法を用いて導出。 (図は[3]より引用)
  42. 42. Copyright © 2022 Morpho, Inc. All Rights Reserved 41 HiPPO パターン2: [0, 𝑡]上に定義 このときの正規直交関数系は、ルジャンドル関数系 𝑃𝑛 𝑥 を用いて 𝑝𝑛 𝑡 𝑥 ≔ 2𝑛 + 1 1/2𝑃 𝑛 2𝑥 𝑡 − 1 𝜕 𝜕𝑡 𝑝𝑛 𝑡 = − 2𝑛 + 1 1/2 1 𝑡 𝑛 2𝑛 + 1 −1/2𝑝𝑛 (𝑡) + 2𝑛 − 1 1/2𝑝𝑛−1 (𝑡) + 2𝑛 − 3 1/2𝑝𝑛−2 (𝑡) + ⋯ またこのとき 𝜔 𝑡 𝑥 = 1 𝑡 1 0,𝑡 = 1 𝑡 𝐻𝑒𝑣𝑖𝑠𝑖𝑑𝑒 𝑥 − 𝐻𝑒𝑣𝑖𝑠𝑖𝑑𝑒 𝑥 − 𝑡 𝜕 𝜕𝑡 𝜔 𝑡 = − 1 𝑡2 1 0,𝑡 + 1 𝑡 𝛿 𝑥 − 𝑡 = 1 𝑡 (−𝜔(𝑡) + 𝛿𝑡) (図は[3]より引用)
  43. 43. Copyright © 2022 Morpho, Inc. All Rights Reserved 42 HiPPO パターン2: [0, 𝑡]上に定義 これを先ほどの式に代入すると 第一項 = න 𝑓 ∗ 𝜕 𝜕𝑡 𝑝𝑛 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 = − 2𝑛 + 1 1 2 1 𝑡 𝑛 2𝑛 + 1 − 1 2𝑐𝑛 + 2𝑛 − 1 1/2𝑐𝑛−1 + 2𝑛 − 3 1/2𝑐𝑛−2 + ⋯ 第二項 = න 𝑓 ∗ 𝑝𝑛 𝑡 ∗ 𝜕 𝜕𝑡 𝜔 𝑡 𝑑𝑥 = − 1 𝑡 𝑐𝑛 𝑡 + 𝑓 𝑡 𝑝𝑛 𝑡 (𝑡) 辺々加えてこねこねすれば、 𝑐′ 𝑡 = − 1 𝑡 𝐴𝑐 𝑡 + 1 𝑡 𝐵𝑓 𝑡 𝐴𝑛𝑘 = ൞ 2𝑛 + 1 1 2 2𝑘 + 1 1 2 (𝑛 > 𝑘) 𝑛 + 1 (𝑛 = 𝑘) 0 (𝑛 < 𝑘) , 𝐵𝑛 = 2𝑛 + 1 1 2 (図は[3]より引用)
  44. 44. Copyright © 2022 Morpho, Inc. All Rights Reserved 43 HiPPO パターン2: [0, 𝑡]上に定義 Def. 測度 1 𝑡 1 0,𝑡 から導出されるHiPPOの時間発展式 𝑐′ 𝑡 = − 1 𝑡 𝐴𝑐 𝑡 + 1 𝑡 𝐵𝑓 𝑡 ただし 𝐴𝑛𝑘 = ൞ 2𝑛 + 1 1 2 2𝑘 + 1 1 2 (𝑛 > 𝑘) 𝑛 + 1 (𝑛 = 𝑘) 0 (𝑛 < 𝑘) , 𝐵𝑛 = 2𝑛 + 1 1 2 をHiPPO-LegSと呼ぶ。(scaled Legendre) 実はこれがまさに本論文で提案する新手法に他ならない! (図は[3]より引用)
  45. 45. Copyright © 2022 Morpho, Inc. All Rights Reserved 44 HiPPO ラゲール、チェビシェフ、エルミート等、他の直交関数系に対しても導出が可能。 HiPPO-LegSは[0,t]のすべての時刻を見る点で直感的にHiPPO-LegTよりも優 れているが、以降でこの式が多くの嬉しい性質を満たすことを見る。 ここまでのまとめ • ℎ𝑖𝑝𝑝𝑜の出力(𝑐0, 𝑐1, … , 𝑐𝑁−1)の時間変化はlinear ODEで書ける。 • ODEの係数行列は陽に書けて実際に計算可能。 • HiPPOの枠組みで既存手法を導出可能(HiPPO-LegT)。 • HiPPO-LegSという新しい時間発展式を提案。
  46. 46. Copyright © 2022 Morpho, Inc. All Rights Reserved 45 HiPPO 最後にHiPPO-LegSの持つ良い性質を見ていこう。 スペースの都合上、ここでは時間スケールに依存しないことだけ見る。 その他の性質は最後に結果のみを列挙する。
  47. 47. Copyright © 2022 Morpho, Inc. All Rights Reserved 46 HiPPO (証明) 前述のODEを計算するにあたり、まずは離散化をしないといけない。 𝑐′ 𝑡 = − 1 𝑡 𝐴𝑐 𝑡 + 1 𝑡 𝐵𝑓 𝑡 の両辺を積分して、 𝑐 𝑡 + Δ𝑡 − 𝑐 𝑡 = න 𝑡 𝑡+Δ𝑡 − 1 𝑡 𝐴𝑐 𝑡 + 1 𝑡 𝐵𝑓 𝑡 𝑑𝑡 ≈ Δ𝑡 2 − 1 𝑡 𝐴𝑐 𝑡 + 1 𝑡 𝐵𝑓 𝑡 + − 1 𝑡 + Δ𝑡 𝐴𝑐 𝑡 + Δ𝑡 + 1 𝑡 + Δ𝑡 𝐵𝑓 𝑡 + Δ𝑡 Lemma ([3, Appendix B]) HiPPO-LegSは時間スケールに依存しない。
  48. 48. Copyright © 2022 Morpho, Inc. All Rights Reserved 47 HiPPO 辺々整理すると 𝐼 + Δ𝑡 2 𝑡 + Δ𝑡 𝐴 𝑐 𝑡 + Δ𝑡 = 𝐼 − Δ𝑡 2𝑡 𝐴 𝑐 𝑡 + Δ𝑡 2 𝑡 + Δ𝑡 + Δ𝑡 2𝑡 𝐵𝑓(𝑡) なお𝑓 𝑡 + Δ𝑡 = 𝑓(𝑡)の仮定を暗黙に使った。 ここで𝑡 = 𝑘Δ𝑡, 𝑐𝑘 ≔ 𝑐 𝑘Δ𝑡 , 𝑓𝑘 ≔ 𝑓(𝑘Δ𝑡)とすれば、 𝐼 + 1 2(𝑘 + 1) 𝐴 𝑐𝑘+1 = 𝐼 − 1 2𝑘 𝐴 𝑐𝑘 + 1 2 𝑘 + 1 + 1 2𝑘 𝐵𝑓(𝑡) ⇒どこにもΔ𝑡が出てこない!(証明終わり) (HiPPO-LegTなど他の直交関数系だとこうはならない)
  49. 49. Copyright © 2022 Morpho, Inc. All Rights Reserved 48 HiPPO 上の結果と合わせて、他の性質もまとめて理論説明を終わる。 ここまでのまとめ • HiPPO-LegSは時間スケールに依存しない。(ドメインシフトに強い) • HiPPOの1回のdiscretized ODE計算はO(N)。 • 𝑘 ∈ 𝑁: fixedおよび∀𝑙 > 𝑘に対して 𝜕𝑐𝑙+1 𝜕𝑓𝑘 = 𝑂(1/𝑙) (勾配消失・爆発しない!) • 𝑓𝑥≤𝑡の𝑆𝑝𝑎𝑛⟨𝑃0, … , 𝑃𝑁−1⟩への射影を𝑔(𝑡)としたとき • 𝑓が𝐿-Lipschitzなら 𝑓𝑥≤𝑡 − 𝑔(𝑡) = 𝑂 𝑡𝐿 𝑁 • 𝑓の𝑘回微分が有界なら 𝑓𝑥≤𝑡 − 𝑔(𝑡) = 𝑂 𝑡𝑘𝑁−𝑘+1/2
  50. 50. Copyright © 2022 Morpho, Inc. All Rights Reserved 49 HiPPO 実験 HiPPOの離散漸化式をRNNに組み込んで性能評価してみる。 hidden state ℎ𝑡の履歴を記憶させるよう下図のモデル設計を採用。 ([3]より引用)
  51. 51. Copyright © 2022 Morpho, Inc. All Rights Reserved 50 HiPPO 実験 タスク1: Permuted MNIST ([3]より引用)
  52. 52. Copyright © 2022 Morpho, Inc. All Rights Reserved 51 HiPPO 実験 タスク2: Character Trajectory Classification ペン先の3次元速度情報から書いている文字を当てるタスク。 サンプリングレートを変えてドメインシフトを再現しているが、 HiPPO-LegSは影響を受けていない。 ([3]より引用)
  53. 53. Copyright © 2022 Morpho, Inc. All Rights Reserved 52 HiPPO 実験 タスク3: Copying ([3]より引用)
  54. 54. Copyright © 2022 Morpho, Inc. All Rights Reserved 53 LSSL Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers (NeurIPS 2021)
  55. 55. Copyright © 2022 Morpho, Inc. All Rights Reserved 54 LSSL 概要 HiPPOを以下のように改良する。 • ቊ 𝑥′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) の形に増強。 • 𝐴, 𝐵, 𝐶, 𝐷を学習パラメータに変更。 • 上記連立方程式がCNN/RNNの要素を含むことを証明。
  56. 56. Copyright © 2022 Morpho, Inc. All Rights Reserved 55 LSSL 手法 動機は論文に書いてないが、HiPPOの式を ቊ 𝑥′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) の形に増強する。 (状態空間モデルの方程式を意識してると思われる) しれっとA,B,C,Dは時間依存しないことになってる? HiPPO-LegSは係数行列は時間依存してたが。。。。。。 t → ∞ でAはほぼ変化しないので定数とみなしてるのかも。
  57. 57. Copyright © 2022 Morpho, Inc. All Rights Reserved 56 LSSL 手法 𝑦 𝑡 が𝑥 𝑡 と𝑢(𝑡)の線形和であることに注目する。 ቊ 𝑥′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) これをLSSL(Linear State Space Layer)と呼ぶ。
  58. 58. Copyright © 2022 Morpho, Inc. All Rights Reserved 57 LSSL 手法 𝑦 𝑡 が𝑥 𝑡 と𝑢(𝑡)の線形和であることに注目する。 ቊ 𝑥′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) これをLSSL(Linear State Space Layer)と呼ぶ。 以下この方程式が次の性質を持つことを順にみていこう。 1.線形であるために、RNNより高速に計算可能。 2.線形だと貧弱な気がするが、実は十分な表現力を持つ。
  59. 59. Copyright © 2022 Morpho, Inc. All Rights Reserved 58 LSSL 1.高速に計算可能 まずLSSLをbilinear離散化すると、特に第1式について積分して 𝑥’ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢 𝑡 𝑥 𝑡 + Δ𝑡 − 𝑥 𝑡 = Δ𝑡 2 𝐴𝑥 𝑡 + 𝐵𝑢 𝑡 + 𝐴𝑥 𝑡 + Δ𝑡 + 𝐵𝑢 𝑡 + Δ𝑡 𝑥 𝑡 + Δ𝑡 = ҧ 𝐴𝑥 𝑡 + ത 𝐵𝑢 𝑡 ただし ҧ 𝐴 = 𝐼 − Δ 2 𝐴 −1 𝐼 + Δ 2 𝐴 , ത 𝐵 = 𝐼 − Δ 2 𝐴 −1 Δ𝐵
  60. 60. Copyright © 2022 Morpho, Inc. All Rights Reserved 59 LSSL 1.高速に計算可能 この離散化式 ൝ 𝑥𝑘 = ҧ 𝐴𝑥𝑘−1 + ത 𝐵𝑢𝑘 𝑦𝑘 = ҧ 𝐶𝑥𝑘 + ഥ 𝐷𝑢𝑘 から𝑥を削除すると、𝑥−1 = 0として 𝑦0 = ҧ 𝐶 ത 𝐵𝑢0 + ഥ 𝐷𝑢0 𝑦1 = ҧ 𝐶 ҧ 𝐴 ത 𝐵𝑢0 + ത 𝐵𝑢1 + ഥ 𝐷𝑢1 𝑦2 = ҧ 𝐶 ҧ 𝐴 ҧ 𝐴 ത 𝐵𝑢0 + ത 𝐵𝑢1 + ത 𝐵𝑢2 + ഥ 𝐷𝑢2 … … … 𝑦𝑘 = ҧ 𝐶 ҧ 𝐴 𝑘 ത 𝐵𝑢0 + ҧ 𝐶 ҧ 𝐴 𝑘−1 ത 𝐵𝑢1 + ⋯ + ҧ 𝐶 ത 𝐵𝑢𝑘 + ഥ 𝐷𝑢𝑘
  61. 61. Copyright © 2022 Morpho, Inc. All Rights Reserved 60 LSSL 1.高速に計算可能 ഥ 𝐷はお尻にしか付かないのでഥ 𝐷 = 0として無視しよう。すると 𝑦𝑘 = ҧ 𝐶 ҧ 𝐴 𝑘 ത 𝐵𝑢0 + ҧ 𝐶 ҧ 𝐴 𝑘−1 ത 𝐵𝑢1 + ⋯ + ҧ 𝐶 ത 𝐵𝑢𝑘 となり、この式はまさに𝑦 = 𝐾𝐿( ҧ 𝐴, ത 𝐵, ҧ 𝐶) ∗ 𝑢のconvolutionに他ならない。 𝐾𝐿 ҧ 𝐴, ത 𝐵, ҧ 𝐶 ≔ ( ҧ 𝐶 ത 𝐵, ҧ 𝐶 ҧ 𝐴 ത 𝐵, … , ҧ 𝐶 ҧ 𝐴𝐿−1 ത 𝐵) ここで𝐿はシーケンス長を表す。 これよりrecurrenceが不要になり、計算は高速。
  62. 62. Copyright © 2022 Morpho, Inc. All Rights Reserved 61 LSSL 2.十分な表現力を持つ Lemma ([5, Lemma 3.1]) LSSLはbackward-Eulerで離散化した場合、 RNNのgating mechanismを包含する。 (証明) LSSLの第一式 𝑥′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) をbackward-Eulerで離散化すると 𝑥 𝑡 + Δ𝑡 − 𝑥 𝑡 = න 𝑡 𝑡+Δ𝑡 𝐴𝑥 𝑡 + 𝐵𝑢 𝑡 𝑑𝑡 ≈ Δ𝑡 𝐴𝑥 𝑡 + Δ𝑡 + 𝐵𝑢(𝑡 + Δ𝑡)
  63. 63. Copyright © 2022 Morpho, Inc. All Rights Reserved 62 LSSL 2.十分な表現力を持つ Lemma ([5, Lemma 3.1]) LSSLはbackward-Eulerで離散化した場合、 RNNのgating mechanismを包含する。 𝑥𝑘 ≔ 𝑥 𝑡 , 𝑥𝑘+1 ≔ 𝑥 𝑡 + Δ𝑡 , 𝑢𝑘+1 ≔ 𝑢(𝑡 + Δ𝑡)とし、さらにΔ𝑡 = 𝑒𝑧とおけば、 𝑥𝑘+1 − 𝑥𝑘 ≈ 𝑒𝑧 𝐴𝑥𝑘+1 + 𝐵𝑢𝑘+1 𝑥𝑘+1 ≈ 1 − 𝐴𝑒𝑧 1 + 𝑒𝑧 𝑥𝑘 + 𝐵𝑒𝑧 1 + 𝑒𝑧 𝑢𝑘 ここで𝐴 = 𝐵 = 1とすれば、 𝑥𝑘+1 ≈ 1 − 𝜎 𝑧 𝑥𝑘 + 𝜎 𝑧 𝑢𝑘となり、 これはgating mechanismに他ならない。(証明終わり)
  64. 64. Copyright © 2022 Morpho, Inc. All Rights Reserved 63 LSSL 2.十分な表現力を持つ Lemma ([5, Lemma 3.2]) 𝑓(𝑡, 𝑥)がxについて局所Lipstizsである非線形関数としたとき、 無限にLSSLをstackしたモデルは𝑥’ 𝑡 = −𝑥 𝑡 + 𝑓(𝑡, 𝑥(𝑡))を解ける。 (証明概略) LSSLの線形部分をstackすると、それが実質ピカールの逐次近似 法を回していることになっている。 非線形部分𝑓はLSSL間にpointwise non-linearityな層を挟むことで再現す る。(証明終わり) ※この命題は本筋には使わない。詳細は各自論文参照。
  65. 65. Copyright © 2022 Morpho, Inc. All Rights Reserved 64 LSSL ここまでのまとめ • HiPPOにさらに線形方程式を追加したLSSLを提案。 • LSSLはconvolutionとして解釈可能なため高速。 • LSSLはRNNを含み、non-linear ODEを解くだけの能力を持つ。
  66. 66. Copyright © 2022 Morpho, Inc. All Rights Reserved 65 LSSL LSSLがHiPPOより真に優位であることは分かった。 次にこれを実際にどう学習に組み込むかを見ていく。 特に • Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい。 • Convolution:𝐾𝐿 ҧ 𝐴, ത 𝐵, ҧ 𝐶 ≔ ( ҧ 𝐶 ത 𝐵, ҧ 𝐶 ҧ 𝐴 ത 𝐵, … , ҧ 𝐶 ҧ 𝐴𝐿−1 ത 𝐵)を如何に高速計算するか。 を調べたい。
  67. 67. Copyright © 2022 Morpho, Inc. All Rights Reserved 66 LSSL Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい AはHiPPOで導出されるような行列のクラスに限定したい。 一体それはどんな形で書けるのだろうか? Convolution:𝑲𝑳 ഥ 𝑨, ഥ 𝑩, ഥ 𝑪 ≔ (ഥ 𝑪ഥ 𝑩, ഥ 𝑪ഥ 𝑨ഥ 𝑩, … , ഥ 𝑪ഥ 𝑨𝑳−𝟏 ഥ 𝑩)を如何に高速計算するか この式の中にはAのべき乗が大量に入っているので、愚直計算で𝑂(𝑁3 𝐿)かかる。 もっと速く計算できないだろうか?
  68. 68. Copyright © 2022 Morpho, Inc. All Rights Reserved 67 LSSL Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい AはHiPPOで導出されるような行列のクラスに限定したい。 一体それはどんな形で書けるのだろうか? Convolution:𝑲𝑳 𝑨, 𝑩, 𝑪 ≔ (𝑪𝑩, 𝑪𝑨𝑩, … , 𝑪𝑨𝑳−𝟏𝑩)を如何に高速計算するか この式の中にはAのべき乗が大量に入っているので、愚直計算で𝑂(𝑁3 𝐿)かかる。 もっと速く計算できないだろうか? ここで残念なお知らせ LSSLの論文でこの考察をしているが、 その結果はお世辞にもきれいとは言えない。 しかも計算は非常に不安定。
  69. 69. Copyright © 2022 Morpho, Inc. All Rights Reserved 68 LSSL Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい AはHiPPOで導出されるような行列のクラスに限定したい。 一体それはどんな形で書けるのだろうか? Convolution:𝑲𝑳 𝑨, 𝑩, 𝑪 ≔ (𝑪𝑩, 𝑪𝑨𝑩, … , 𝑪𝑨𝑳−𝟏𝑩)を如何に高速計算するか この式の中にはAのべき乗が大量に入っているので、愚直計算で𝑂(𝑁3 𝐿)かかる。 もっと速く計算できないだろうか? これらの問題点は S4の論文にて 1年越しに解決!
  70. 70. Copyright © 2022 Morpho, Inc. All Rights Reserved 69 S4 Efficiently Modeling Long Sequences with Structured State Spaces (ICLR 2022 Oral)
  71. 71. Copyright © 2022 Morpho, Inc. All Rights Reserved 70 S4 概要 LSSLで消化不良だった • Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい。 • Convolution:𝐾𝐿 ҧ 𝐴, ത 𝐵, ҧ 𝐶 ≔ ( ҧ 𝐶 ത 𝐵, ҧ 𝐶 ҧ 𝐴 ത 𝐵, … , ҧ 𝐶 ҧ 𝐴𝐿−1 ത 𝐵)を如何に高速計算するか。 を解決する。
  72. 72. Copyright © 2022 Morpho, Inc. All Rights Reserved 71 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい 一般的なHiPPO行列の形を導出するのは難しい。(LSSLの論文ではそれをやって大変汚いことに) そこで 「計算しやすさ」 と 「HiPPO-LegT/LegSを含む」 ことを条件に、学習する行列Aのクラスを決める。
  73. 73. Copyright © 2022 Morpho, Inc. All Rights Reserved 72 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Def. 行列𝐴 ∈ 𝑅𝑛∗𝑛が 𝐴 = 𝐹 − 𝑝𝑞𝑇 (𝐹: 𝑛𝑜𝑟𝑚𝑎𝑙, 𝑝, 𝑞 ∈ ℝ𝑛∗𝑘 𝑘 ≪ 𝑛 ) と書けるとき、𝐴はNPLR(Normal Plus Low-Rank)表現を持つという。 (Plusと言いつつマイナスにしているのは、本スライドでの説明の都合による) Fact 以下は同値: 1. 𝐹はnormal (i.e. 𝐹𝐹∗ = 𝐹∗𝐹) 2. 𝐹はユニタリ行列で対角化可能
  74. 74. Copyright © 2022 Morpho, Inc. All Rights Reserved 73 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Theorem 1]) HiPPO LegT/LegSはNPLR表現を持つ。 (証明) 以下ではHiPPO LegSのみ見ていく。このとき行列𝐴は 𝐴𝑛𝑘 = − ൞ 2𝑛 + 1 1 2 2𝑘 + 1 1 2 (𝑛 > 𝑘) 𝑛 + 1 (𝑛 = 𝑘) 0 (𝑛 < 𝑘) と書けた。
  75. 75. Copyright © 2022 Morpho, Inc. All Rights Reserved 74 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Theorem 1]) HiPPO LegT/LegSはNPLR表現を持つ。 ここで𝑝 = 2𝑛+1 2 1 2 𝑛 とすると、 𝑝𝑝𝑇 𝑛𝑘 = 1 2 2𝑛 + 1 1 2 2𝑘 + 1 1 2 であり、 𝐴 + 𝑝𝑝𝑇 𝑛𝑘 = − 1 2 2𝑛 + 1 1 2 2𝑘 + 1 1 2 (𝑛 > 𝑘) ∗∗∗ 略 ∗∗∗ (𝑛 = 𝑘) − 1 2 2𝑛 + 1 1 2 2𝑘 + 1 1 2(𝑛 < 𝑘)
  76. 76. Copyright © 2022 Morpho, Inc. All Rights Reserved 75 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Theorem 1]) HiPPO LegT/LegSはNPLR表現を持つ。 すなわち 𝐴 + 𝑝𝑝𝑇 = 𝑠𝑘𝑒𝑤_𝑠𝑦𝑚𝑚𝑒𝑡𝑟𝑖𝑐 + 𝑘𝐼, ∃𝑘 ∈ ℝ の形になっており、特に右辺は正規行列。(証明終わり)
  77. 77. Copyright © 2022 Morpho, Inc. All Rights Reserved 76 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい これを踏まえて行列AはNPLRの中で学習させることを考える。 が、実はさらにクラスを制限しても問題ないことを次に示す。 Def. 行列𝐴 ∈ 𝑅𝑛∗𝑛が 𝐴 = Λ − 𝑝𝑞𝑇 (Λ: 𝑑𝑖𝑎𝑔𝑜𝑛𝑎𝑙, 𝑝, 𝑞 ∈ ℝ𝑛∗𝑘 𝑘 ≪ 𝑛 ) と書けるとき、𝐴はDPLR(Diagonal Plus Low-Rank)表現を持つという。
  78. 78. Copyright © 2022 Morpho, Inc. All Rights Reserved 77 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Lemma 3.1]) HiPPO行列に共役な作用を施しても出力は不変。 (証明) 主張がやや不明瞭だが、証明を見れば意味が分かる。 ቊ 𝑥′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) に対して 𝐴, 𝐵, 𝐶, 𝐷 → (𝑉−1𝐴𝑉, 𝑉−1𝐵, 𝐶𝑉, 𝐷)の変換を施すと、 ൝ 𝑥′ 𝑡 = 𝑉−1𝐴𝑉𝑥 𝑡 + 𝑉−1𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑉𝑥 𝑡 + 𝐷𝑢(𝑡) ↔ ቊ 𝑉𝑥′ 𝑡 = 𝐴𝑉𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑉𝑥 𝑡 + 𝐷𝑢(𝑡)
  79. 79. Copyright © 2022 Morpho, Inc. All Rights Reserved 78 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Lemma 3.1]) HiPPO行列に共役な作用を施しても出力は不変。 すなわち𝑉の共役な作用が𝐴にかかっても、𝐵, 𝐶を適切に変換すれば、 作用の影響は潜在変数の変数変換にとどまる。(証明終わり) これより行列𝐴をDPLRの中で学習させるとしても問題ない。
  80. 80. Copyright © 2022 Morpho, Inc. All Rights Reserved 79 S4 ここまでのまとめ 𝐴 = Λ − 𝑝𝑞𝑇 として、Λ, 𝑝, 𝑞を学習させることにする。 これにより求まる𝐴の属する空間は、 古典的な直交関数形に対するHiPPO行列たちを含む。
  81. 81. Copyright © 2022 Morpho, Inc. All Rights Reserved 80 S4 𝐾𝐿 ҧ 𝐴, ത 𝐵, ҧ 𝐶 ≔ ( ҧ 𝐶 ത 𝐵, ҧ 𝐶 ҧ 𝐴 ത 𝐵, … , ҧ 𝐶 ҧ 𝐴𝐿−1 ത 𝐵)の高速計算 ここが本論文の山場。 なんと上記のconvolutionカーネル計算を、 愚直計算の𝑂(𝑁3𝐿)からなんと ෨ 𝑂(𝑁 + 𝐿)にまで落としてしまう。 超絶技巧が盛りだくさんなので、step-by-stepに追っていこう。
  82. 82. Copyright © 2022 Morpho, Inc. All Rights Reserved 81 S4 STEP0. 先ほど述べたように、 𝐴 = Λ − 𝑝𝑞𝑇 (Λ: 𝑑𝑖𝑎𝑔𝑜𝑛𝑎𝑙, 𝑝, 𝑞 ∈ ℝ𝑛∗1) と置く。説明簡単化のため、𝑝, 𝑞は𝑛 ∗ 1行列とする。(HiPPO-LegSはそう) また統一性のため ቊ 𝑥′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶∗𝑥 𝑡 + 𝐷𝑢(𝑡) のように𝐶を転置して、𝐶が𝐵, 𝑝, 𝑞と同じℝ𝑛∗1の元であるようにする。
  83. 83. Copyright © 2022 Morpho, Inc. All Rights Reserved 82 S4 STEP1. 𝐾𝐿 ҧ 𝐴, ത 𝐵, ҧ 𝐶∗ ≔ ( ҧ 𝐶∗ ത 𝐵, ҧ 𝐶∗ ҧ 𝐴 ത 𝐵, … , ҧ 𝐶∗ ҧ 𝐴𝐿−1 ത 𝐵) を直接求めるのではなく、それのz変換もどき ෡ 𝐾𝐿 𝑧; ҧ 𝐴, ത 𝐵, ҧ 𝐶∗ ≔ ෍ 𝑖=0 𝐿−1 ҧ 𝐶∗ ҧ 𝐴𝑖 ത 𝐵𝑧𝑖 ∈ ℂ[𝑧] を求めることを考える。 ෡ 𝐾𝐿から𝐾𝐿を導出するのは、zに1のべき根を突っ込んでiFFTにより𝑂(𝐿 log 𝐿)
  84. 84. Copyright © 2022 Morpho, Inc. All Rights Reserved 83 S4 STEP2. Lemma ([6, Lemma C.3]) ෡ 𝐾𝐿 𝑧; ҧ 𝐴, ത 𝐵, ҧ 𝐶∗ = 2 1 + 𝑧 ሚ 𝐶∗𝑅 𝑧 𝐵 − ሚ 𝐶∗𝑅 𝑧 𝑝 1 + 𝑞∗𝑅 𝑧 𝑝 −1𝑞∗𝑅 𝑧 𝐵 ただし ሚ 𝐶 = 𝐶 𝐼 − ҧ 𝐴𝐿 , 𝑅 𝑧; Λ = 2 Δ 1 − 𝑧 1 + 𝑧 − Λ −1 (証明) 形式べき級数を用いて、mod 𝑧𝐿 の下で ෡ 𝐾𝐿 𝑧; ҧ 𝐴, ത 𝐵, ҧ 𝐶∗ ≔ ෍ 𝑖=0 𝐿−1 ҧ 𝐶∗ ҧ 𝐴𝑖 ത 𝐵𝑧𝑖 = ҧ 𝐶∗ 𝐼 − ҧ 𝐴𝐿 𝐼 − ҧ 𝐴𝑧 −1 ത 𝐵 = ሚ 𝐶∗ 𝐼 − ҧ 𝐴𝑧 −1 ത 𝐵
  85. 85. Copyright © 2022 Morpho, Inc. All Rights Reserved 84 S4 また、LSSLの離散化手続きを思い出すと ҧ 𝐴 = 𝐼 − Δ 2 𝐴 −1 𝐼 + Δ 2 𝐴 , ത 𝐵 = 𝐼 − Δ 2 𝐴 −1 Δ𝐵 であるが、これを前述の式に代入すると以下を得る(詳細は[6, Lemma C.4]) ሚ 𝐶∗ 𝐼 − ҧ 𝐴𝑧 −1 ത 𝐵 = 2Δ 1 + 𝑧 ሚ 𝐶∗ 2 1 − 𝑧 1 + 𝑧 𝐼 − Δ𝐴 −1 𝐵
  86. 86. Copyright © 2022 Morpho, Inc. All Rights Reserved 85 S4 ここでさらに𝐴 = Λ − 𝑝𝑞𝑇なことを思い出すと ሚ 𝐶∗ 𝐼 − ҧ 𝐴𝑧 −1 ത 𝐵 = 2Δ 1 + 𝑧 ሚ 𝐶∗ 2 1 − 𝑧 1 + 𝑧 𝐼 − Δ Λ − 𝑝𝑞∗ −1 𝐵 = 2 1 + 𝑧 ሚ 𝐶∗ 2 Δ 1 − 𝑧 1 + 𝑧 𝐼 − Λ + 𝑝𝑞∗ −1 𝐵 = 2 1 + 𝑧 ሚ 𝐶∗𝑅 𝑧 𝐵 − ሚ 𝐶∗𝑅 𝑧 𝑝 1 + 𝑞∗𝑅 𝑧 𝑝 −1 𝑞∗𝑅 𝑧 𝐵 なお最後の等号はWoodbury Identityから従う。(証明終わり) Fact (Woodbury Identity) 任意の行列𝐴, 𝑃, 𝑄に対して以下が成り立つ 𝐴 + 𝑈𝑉∗ −1 = 𝐴−1 − 𝐴−1𝑈 𝐼 + 𝑉∗𝐴−1𝑈 −1𝑉∗𝐴−1 diagonal
  87. 87. Copyright © 2022 Morpho, Inc. All Rights Reserved 86 S4 STEP2. 求めた式は一見煩雑になっただけに見えるが、よく見ると赤線部分 ෡ 𝐾𝐿 𝑧; ҧ 𝐴, ത 𝐵, ҧ 𝐶∗ = 2 1 + 𝑧 ሚ 𝐶∗ 𝑅 𝑧 𝐵 − ሚ 𝐶∗ 𝑅 𝑧 𝑝 1 + 𝑞∗ 𝑅 𝑧 𝑝 −1 𝑞∗ 𝑅 𝑧 𝐵 はすべてスカラーであり、かつ𝑅 𝑧; Λ = 2 Δ 1−𝑧 1+𝑧 − Λ −1 は対角行列。 すなわち上の計算は登場する行列たちが既知なら𝑂(𝑁)で求まる。
  88. 88. Copyright © 2022 Morpho, Inc. All Rights Reserved 87 S4 STEP3. よってあとは新規の登場人物たち、とくに ሚ 𝐶 = 𝐶 𝐼 − ҧ 𝐴𝐿 , 𝑅 𝑧; Λ = 2 Δ 1 − 𝑧 1 + 𝑧 − Λ −1 の2つが高速に求められれば良い。 前者は発想の転換で、𝐶ではなく ሚ 𝐶を最初から学習させることにすれば解決。
  89. 89. Copyright © 2022 Morpho, Inc. All Rights Reserved 88 S4 STEP3. 𝑅 𝑧; Λ = 2 Δ 1 − 𝑧 1 + 𝑧 − Λ −1 だが、一見対角行列なので𝑂(𝑁)で計算可能で、何も問題ないように見える。 しかしSTEP1を見直すと、我々は ෡ 𝐾𝐿 𝑧; ҧ 𝐴, ത 𝐵, ҧ 𝐶∗ をすべての1の𝐿乗根に対して求める必要がある。 よってこのままだと𝑂(𝑁𝐿)かかってしまいまずい。
  90. 90. Copyright © 2022 Morpho, Inc. All Rights Reserved 89 S4 STEP3. どのみち𝑅(𝑧)を𝑂(𝑁)より早く求めたとしても、STEP2の計算をすべての1の𝐿 乗根に対して行うと𝑂(𝑁𝐿)かかってしまう。 そこで少し視点を変えて、一般に赤線部分 𝑉∗𝑅 𝑧 𝑈, (∀𝑈, 𝑉 ∈ ℝ𝑛∗1) をすべての𝑧 ∈ {1の𝐿乗根}に対して一括で ෨ 𝑂(𝑁 + 𝐿)で求めることを考える。 実は𝑅(𝑧)の特殊構造により、これが可能である。
  91. 91. Copyright © 2022 Morpho, Inc. All Rights Reserved 90 S4 STEP3. Def. K ∈ ℝ𝑀∗𝑁であって、 𝐾𝑖𝑗 = 1 𝜔𝑖 − 𝜆𝑗 , (𝜔𝑖, 𝜆𝑗 ∈ ℂ) と書けるものをCauchy Kernelと呼ぶ。 Fact [7] Cauchy Kernelの行列ベクトル積にかかる計算量は ൞ 𝑂 𝑀 + 𝑁 log2 𝑀 + 𝑁 , 𝑒𝑥𝑎𝑐𝑡 𝑎𝑟𝑖𝑡ℎ𝑚𝑒𝑡𝑖𝑐 𝑂 𝑀 + 𝑁 log 𝑀 + 𝑁 log 1 𝜖 , 𝑛𝑢𝑚𝑒𝑟𝑖𝑐𝑎𝑙𝑙𝑦 𝑡𝑜 𝑝𝑟𝑒𝑐𝑖𝑠𝑖𝑜𝑛 𝜖
  92. 92. Copyright © 2022 Morpho, Inc. All Rights Reserved 91 S4 STEP3. これを踏まえると 𝑅 𝑧; Λ = 2 Δ 1 − 𝑧 1 + 𝑧 − Λ −1 はまさにCauchy Kernelに他ならない。 ゆえにすべての𝑧 ∈ {1の𝐿乗根}に対して、 ෡ 𝐾𝐿 𝑧; ҧ 𝐴, ത 𝐵, ҧ 𝐶∗ の計算は一括で ෨ 𝑂(𝑁 + 𝐿)で終わる。 STEP1のiFFTは𝑂(𝐿 log 𝐿)なので、全体としても ෨ 𝑂(𝑁 + 𝐿)で計算が完了する。 (おわり!)
  93. 93. 92 Copyright © 2022 Morpho, Inc. All Rights Reserved. • 時系列モデリングの新手法HiPPOを提案。 • HiPPOを状態空間モデルの方程式に組み込み、高速なconvolution計算を実現。 • Path-Xタスクで世界初の推論成功を達成。 「所感」 • 概念基盤がかなりしっかりしていて、かつ汎用性が高い。 • 後続研究にS4をaudio generationやvideo classificationに使用した例あり。 「おまけ」 公式実装:https://github.com/HazyResearch/state-spaces 解説付きJax実装:https://srush.github.io/annotated-s4 まとめ まとめ
  94. 94. 93 Copyright © 2022 Morpho, Inc. All Rights Reserved. [1] Yi Tay, Mostafa Dehghani, Samira Abnar, Yikang Shen, Dara Bahri, Philip Pham, Jinfeng Rao, Liu Yang, Sebastian Ruder, and Donald Metzler. Long range arena : A benchmark for efficient transformers. In International Conference on Learning Representations, 2021. [2] 黒田成俊. 関数解析. 共立出版. 1980. [3] Albert Gu, Tri Dao, Stefano Ermon, Atri Rudra, and Christopher R´e. Hippo: Recurrent memory with optimal polynomial projections. In Advances in Neural Information Processing Systems, pages 1474-1487, 2020. [4] Aaron Voelker, Ivana Kajić, and Chris Eliasmith. Legendre memory units: Continuous- time representation in recurrent neural networks. In Advances in Neural Information Processing Systems, pages 15544–15553, 2019. [5] Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, and Christopher R´e. Combining recurrent, convolutional, and continuous-time models with the structured learnable linear state space layer. In Advances in Neural Information Processing Systems, pages 572-585, 2021. まとめ 参考文献
  95. 95. 94 Copyright © 2022 Morpho, Inc. All Rights Reserved. [6] Albert Gu, Karan Goel, and Christopher R´e. Efficiently modeling long sequences with structured state spaces. In International Conference on Learning Representations, 2022. [7] Victor Pan. Structured matrices and polynomials: unified superfast algorithms. Springer Science & Business Media, 2001. まとめ 参考文献
  96. 96. Thank you

Description

Morpho Tech Bolgより本資料に関する記事をご覧ください。
・「HiPPO/S4解説」:https://techblog.morphoinc.com/entry/2022/05/24/102648
・執筆者:CTO室リサーチャー 角田

・Morpho Tech Blog: https://techblog.morphoinc.com/
・Morpho, Inc.: https://www.morphoinc.com/

Transcript

  1. 1. HiPPO/S4解説 2022/04/22 LSTMやTransformerを超える 時系列モデリングの新手法 株式会社モルフォ 角田良太朗
  2. 2. Copyright © 2022 Morpho, Inc. All Rights Reserved 1 Overview Warning 本日の内容はゴリゴリの理論です。 中身は非常に美しいのですが、なにせ1時間しかないこともあり 結構な脱落者が発生する危険が高いです。 それを踏まえ、通常の流れには逆らい、まず結果から説明します。
  3. 3. Copyright © 2022 Morpho, Inc. All Rights Reserved 2 Overview Long-Range Arena[1]というタスクが存在する。 • Efficient Transformer系の長距離依存性を統一的に評価するためのベン チマーク。 • トークン数1000~16000の様々なテストデータを用意。 • タスクは6種類: • LONG LISTOPS • BYTE-LEVEL TEXT CLASSIFICATION • BYTE-LEVEL DOCUMENT RETRIEVAL • IMAGE CLASSIFICATION ON SEQUENCES OF PIXELS • PATHFINDER • PATHFINDER-X(←次ページで詳述)
  4. 4. Copyright © 2022 Morpho, Inc. All Rights Reserved 3 Overview PATHFINDER-X positive negative 128x128の画像中の2点が点線でつながっているか二値判定。 ただし、 画像はflattenして入力、 2D-Convの使用禁止。 ([6]より引用) ([6]より引用)
  5. 5. Copyright © 2022 Morpho, Inc. All Rights Reserved 4 Overview このタスクで2021年11月に大幅なSOTA更新あり。 ([6]より引用)
  6. 6. Copyright © 2022 Morpho, Inc. All Rights Reserved 5 Overview このタスクで2021年11月に大幅なSOTA更新あり。 S4以外のすべてのモデルは、推論に失敗していた(乱択と同程度) ([6]より引用)
  7. 7. Copyright © 2022 Morpho, Inc. All Rights Reserved 6 Overview 本スライドの目標は、この驚異的なモデルS4とは何者なのかを解明すること。 手法をざっと述べると。。。 ([6]より引用)
  8. 8. Copyright © 2022 Morpho, Inc. All Rights Reserved 7 Overview 1. 時系列解析を状態空間モデルとして次のように定式化 𝑥’(𝑡) = 𝐴𝑥(𝑡) + 𝐵𝑢(𝑡) 𝑦(𝑡) = 𝐶𝑥(𝑡) + 𝐷𝑢(𝑡) (𝑢 𝑡 ∈ ℝ𝐿:input, 𝑥 𝑡 ∈ ℝ𝑁∗𝐿: hidden state, 𝑦 𝑡 ∈ ℝ𝐿: output) A, B, C, D はlearned parameters AはHiPPO matrixとして初期化し、正規行列+low-rank matrixの形のみを 取るよう制限する。
  9. 9. Copyright © 2022 Morpho, Inc. All Rights Reserved 8 Overview 2. この式を離散化&展開することで 𝑦𝑘 = 𝐶𝐴𝑘𝐵𝑢0 + 𝐶𝐴𝑘−1𝐵𝑢1 + ⋯ + 𝐶𝐴𝐵𝑢𝑘−1 + 𝐶𝐵𝑢𝑘 𝑦 = 𝐾 ∗ 𝑢 (𝐾 ≔ 𝐶𝐵 + 𝐶𝐴𝐵 + ⋯ + 𝐶𝐴𝐿−1𝐵 ∈ ℝ𝐿) と1D-Convの形にかける。 𝐾が高速計算できれば学習はRNNより高速。 (𝐿回のiterationをせず一発で計算できるので)
  10. 10. Copyright © 2022 Morpho, Inc. All Rights Reserved 9 Overview 3.Kを直接計算せず、これのスペクトラム 𝐹 𝐾 ≔ ෍ 𝑗=0 𝐿−1 𝐾𝑗𝜁𝑗 を計算してiFFTで𝐾𝑗を一括導出したい。 𝐹(𝐾)を一般に𝐾の𝑧変換として求めることを考える。 これはAを前述の形に制限したことから、 Woodbury-Identityを用いて Cauchy Kernel4つの和として記述でき、特に ෨ 𝑂(𝑁 + 𝐿)で導出可能。 (おわり)
  11. 11. Copyright © 2022 Morpho, Inc. All Rights Reserved 10 Overview 何を言ってるのか全く分からないが、状態空間モデルの方程式を解くのに • 適切な係数空間の設計(HiPPO行列) • convolutionに書き直して高速な行列計算 をしたことで、超長距離依存性を保持できたことがポイントみたい。 実際次で見るように、これらの施策は精度にcriticalに効いている。
  12. 12. Copyright © 2022 Morpho, Inc. All Rights Reserved 11 Overview • 係数行列の初期化の影響(sequential CIFAR10で実験) ([6]より引用)
  13. 13. Copyright © 2022 Morpho, Inc. All Rights Reserved 12 Overview • 高速な行列計算アルゴリズムの影響 ([6]より引用)
  14. 14. Copyright © 2022 Morpho, Inc. All Rights Reserved 13 Overview 高度で緻密な理論設計がここまで見事に精度に反映されるような deep learning modelは見たことがない! これは解読する価値がありそうだ。。。!
  15. 15. Copyright © 2022 Morpho, Inc. All Rights Reserved 14 Overview しかし現段階ではいろいろと謎が多すぎる。 • 唐突に出てきた線形方程式は一体? • HiPPOって何? • 正規行列+low-rankでパラメータを書く動機は? • Woodbury-Identityどこで使うんや? • Cauchy kernelがなんで絡むん? • ・・・・・・
  16. 16. Copyright © 2022 Morpho, Inc. All Rights Reserved 15 Overview 色々調べると以下の3本の論文が1セットになっていた。 HiPPO: Recurrent Memory with Optimal Polynomial Projections Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, Christopher Ré (NeurIPS 2020 Spotlight) Combining Recurrent, Convolutional, and Continuous-time Models with the Linear State Space Layer Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Christopher Ré (NeurIPS 2021) Efficiently Modeling Long Sequences with Structured State Spaces Albert Gu, Karan Goel, Christopher Ré (ICLR 2022 Oral)
  17. 17. Copyright © 2022 Morpho, Inc. All Rights Reserved 16 Overview 色々調べると以下の3本の論文が1セットになっていた。 HiPPO: Recurrent Memory with Optimal Polynomial Projections Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, Christopher Ré (NeurIPS 2020 Spotlight) Combining Recurrent, Convolutional, and Continuous-time Models with the Linear State Space Layer Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Christopher Ré (NeurIPS 2021) Efficiently Modeling Long Sequences with Structured State Spaces Albert Gu, Karan Goel, Christopher Ré (ICLR 2022 Oral) First author全部同じ人や。。。 Albert Guさん強すぎ。。。
  18. 18. Copyright © 2022 Morpho, Inc. All Rights Reserved 17 Overview 色々調べると以下の3本の論文が1セットになっていた。 HiPPO: Recurrent Memory with Optimal Polynomial Projections Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, Christopher Ré (NeurIPS 2020 Spotlight) Combining Recurrent, Convolutional, and Continuous-time Models with the Linear State Space Layer Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Christopher Ré (NeurIPS 2021) Efficiently Modeling Long Sequences with Structured State Spaces Albert Gu, Karan Goel, Christopher Ré (ICLR 2022 Oral) First author全部同じ人や。。。 Albert Guさん強すぎ。。。 今日はこれらを解説していく!! S4を分かるには これ全部理解するしかない
  19. 19. Copyright © 2022 Morpho, Inc. All Rights Reserved 18 Overview Warning いよいよ本番です。 頑張って付いてきてください。 分からなくなったらすぐ声を上げてください。
  20. 20. Copyright © 2022 Morpho, Inc. All Rights Reserved 19 HiPPO HiPPO: Recurrent Memory with Optimal Polynomial Projections (NeurIPS 2020 Spotlight) 正直この論文が一番の肝
  21. 21. Copyright © 2022 Morpho, Inc. All Rights Reserved 20 HiPPO 概要 長時間の時系列データを扱う際はRNNやLSTMがよく用いられるが、 • 数万ステップにもなると記憶が抜けてしまう。 • シーケンス長や時間スケールへの暗黙の依存により、テスト時に汎化しない。 • 理論的な解釈が何となくしか与えられていない。 という問題がある。
  22. 22. Copyright © 2022 Morpho, Inc. All Rights Reserved 21 HiPPO 概要 そこで本論文では • 記憶に関する理論的な定式化を与え、既存手法たちをその枠組みで再解釈。 • シーケンス長や時間スケールに依存しない新手法を提案。 • 新手法の優位性を、提案した枠組みを用いて厳密に証明。 している。
  23. 23. Copyright © 2022 Morpho, Inc. All Rights Reserved 22 HiPPO 手法 そもそも「記憶を保持する」という言葉が既に曖昧。 そこでこれを次のように言い換えるところからスタートする。 「記憶」 = 「時間依存する測度」に基づく「入力信号の多項式近似」 1 2
  24. 24. Copyright © 2022 Morpho, Inc. All Rights Reserved 23 HiPPO ①入力信号の多項式近似 ℝ上の入力信号𝑓: ℝ → ℝを考えよう。 目標は 𝑓 𝑥 𝑥 ≤ 𝑡 までが与えられるので、これから𝑓(𝑥)を推定することである。 t 𝑥 𝑓(𝑥) 0
  25. 25. Copyright © 2022 Morpho, Inc. All Rights Reserved 24 HiPPO ①入力信号の多項式近似 {𝑓(𝑥)|𝑥 ≤ 𝑡}の値をすべて使って推定するのが最善だが、我々はこれをモデルに 学習させたいので、すべての値をモデルに記憶させるのはメモリ的に厳しい。 そこで{𝑓(𝑥)|𝑥 ≤ 𝑡}を何らかの低次元表現に保存することを考える。 𝑡 𝑥 𝑓(𝑥) 0
  26. 26. Copyright © 2022 Morpho, Inc. All Rights Reserved 25 HiPPO ①入力信号の多項式近似 ここでは「直交多項式系」を低次元表現として採用する。 Def. 測度𝜇(𝑥)に対する直交多項式系とは多項式の集合 𝑃𝑛 𝑛=0,1,2,… であって deg 𝑃𝑛 = 𝑛, ⟨𝑃𝑛, 𝑃𝑚⟩𝜇 ≔ න 𝑃𝑛 𝑥 𝑃𝑚 𝑥 𝑑𝜇 = 𝑎𝑛,𝑚𝛿𝑛,𝑚 (∃𝑎𝑛,𝑚 ∈ ℝ) を満たすものを言う。 Rem. ℝ上の測度が与えられれば、直交多項式はスケール𝑎𝑛,𝑚を除き一意的。 これは{1, 𝑥, 𝑥2, … }をグラムシュミット直交化すれば示せる。
  27. 27. Copyright © 2022 Morpho, Inc. All Rights Reserved 26 HiPPO ①入力信号の多項式近似 Ex. ルジャンドル多項式 𝑃𝑛 𝑥 ≔ 1 2𝑛𝑛! 𝑑𝑛 𝑑𝑥𝑛 [ 𝑥2 − 1 𝑛] は測度𝜇 𝑥 ≔ 1 −1,1 𝑥 に対する直交多項式。 直交性はググれば証明出てくる。 𝑃𝑛, 𝑃𝑚 = 2 2𝑛+1 𝛿𝑛,𝑚 なおルジャンドル多項式特有の性質として次がある(後で使う) • 𝑃𝑛 1 = 1, 𝑃𝑛 −1 = −1 𝑛 • 𝑃𝑛 ′ = 2𝑛 − 1 𝑃𝑛−1 + 2𝑛 − 3 𝑃𝑛−2 + ⋯ • 𝑥 + 1 𝑃𝑛 ′ 𝑥 = 𝑛𝑃𝑛 + 2𝑛 − 1 𝑃𝑛−1 + 2𝑛 − 3 𝑃𝑛−2 + ⋯
  28. 28. Copyright © 2022 Morpho, Inc. All Rights Reserved 27 HiPPO ①入力信号の多項式近似 この時𝑓[𝑥≤𝑡}を直交多項式 𝑃𝑛 𝑛=0,1,…,𝑁−1で近似することを考えよう。 直交多項式の次数は適当な𝑁未満で打ち切っていることに注意。 Fact ([2, Theorem3.10, Theorem3.5]) 1. 測度𝜇から定まる直交多項式系 𝑃𝑛 𝑛=0,1,…を固定した時、 任意の関数𝑓 ∈ 𝐿2(ℝ; 𝜇)は以下の級数展開を持つ。 𝑓 = ෍ 𝑛=0,1,… 𝑐𝑛𝑃𝑛 , 𝑐𝑛 ≔ 𝑓, 𝑃𝑛 / 𝑃𝑛, 𝑃𝑛 2. 上記級数を𝑛 = 𝑁 − 1で打ち切ったものを𝑓(𝑛) としたとき、 𝑓(𝑛) = 𝑎𝑟𝑔𝑚𝑖𝑛𝑔∈𝑆𝑝𝑎𝑛⟨𝑃0,…,𝑃𝑁−1⟩ 𝑓 − 𝑔 𝜇
  29. 29. Copyright © 2022 Morpho, Inc. All Rights Reserved 28 HiPPO ①入力信号の多項式近似 気持ちとしては、「𝑓をN次未満直交多項式の張る空間に射影」している。 これにより、話を元に戻すと 𝑓{𝑥≤𝑡} ≈ 𝑐0𝑃𝑜 + 𝑐1𝑃1 + ⋯ + 𝑐𝑁−1𝑃𝑁−1 として過去の信号履歴を(𝑐0, 𝑐1, … , 𝑐𝑁−1)の𝑁変数に圧縮することができた。 さらに上式を使って未来の信号を予測することも可能! 𝑓{𝑥≤𝑡} 𝑆𝑝𝑎𝑛⟨𝑃0, … , 𝑃𝑁−1⟩
  30. 30. Copyright © 2022 Morpho, Inc. All Rights Reserved 29 HiPPO ②時間依存する測度 ではどんな測度(直交多項式)を選ぶのが最適だろうか? • 入力信号𝑓は時間が進むごとに履歴がどんどん積み重なるので、 測度𝜇も時間発展させた方がよいだろう。 そこで以後、時刻𝑡における測度を𝜇 𝑡 と記述する。 注意: 特に以降𝜇 𝑡 のsupportは(−∞, 𝑡]に含まれるものとする。
  31. 31. Copyright © 2022 Morpho, Inc. All Rights Reserved 30 HiPPO ここまでのまとめ • 「記憶」=「入力信号𝑓の過去履歴を𝑁次元直交多項式系{𝑃𝑛}に射影」 • 直交多項式系 𝑃𝑛 は測度𝜇 𝑡 に応じて時間発展させる。 Def. 入力信号𝑓の履歴を直交多項式系に射影して、その係数を取得する操作を ℎ𝑖𝑝𝑝𝑜 𝑓{𝑥≤𝑡} ≔ 𝑐0, 𝑐1, … , 𝑐𝑁−1 𝑤ℎ𝑒𝑟𝑒 𝑓{𝑥≤𝑡} ≈ 𝑐0𝑃𝑜 + 𝑐1𝑃1 + ⋯ + 𝑐𝑁−1𝑃𝑁−1 と書き、HiPPO operatorと呼ぶ。(HiPPO=high-order Polynomial Projection Operator)
  32. 32. Copyright © 2022 Morpho, Inc. All Rights Reserved 31 HiPPO 「記憶」=(𝑐0, 𝑐1, … , 𝑐𝑁−1)なことは分かったので、次は 「記憶のアップデート」= (𝑐0, 𝑐1, … , 𝑐𝑁−1)の時間発展 がどうなっているか導出したい。 実は驚くべき結論が成り立つ。 Theorem ([3, Appendix C]) 古典的な直交関数系に対して、𝑐(𝑡)の時間発展はlinear ODEで記述できる: 𝑐′ 𝑡 = 𝐴 𝑡 𝑐 𝑡 + 𝐵 𝑡 𝑓 𝑡 , (∃𝐴 𝑡 ∈ ℝ𝑁∗𝑁, ∃𝐵 𝑡 ∈ ℝ𝑁∗1)
  33. 33. Copyright © 2022 Morpho, Inc. All Rights Reserved 32 HiPPO 冒頭で線形方程式が出てきた理由はまさにこれ。 以下これを証明し、具体例を与える。 Notations • 𝑓, 𝜇(𝑡), 𝑃𝑛 (𝑡) :入力信号、時刻tでの測度、付随する直交多項式 • 𝑑𝜇(𝑡) = 𝜔(𝑡) 𝑥 𝑑𝑥、また𝜇(𝑡)は確率測度であると仮定(i.e.‫׬‬ 𝑑𝜇(𝑡) = 1) • 𝑝𝑛 (𝑡) 𝑥 = 𝑃𝑛 𝑡 (𝑥)/⟨𝑃𝑛 𝑡 , 𝑃𝑛 𝑡 ⟩(正規化)
  34. 34. Copyright © 2022 Morpho, Inc. All Rights Reserved 33 HiPPO (証明) まず係数𝑐𝑛(𝑡)の構成を思い出せば 𝑐𝑛 𝑡 = 𝑓≤𝑡, 𝑃𝑛 (𝑡) / 𝑃𝑛 (𝑡) , 𝑃𝑛 (𝑡) = න 𝑓 𝑥 ∗ 𝑝𝑛 𝑡 𝑥 ∗ 𝜔 𝑡 (𝑥)𝑑𝑥 両辺を𝑡で微分すると(積分と微分の可換性は認める) 𝑐𝑛 ′ (𝑡) = න 𝑓 ∗ 𝜕 𝜕𝑡 𝑝𝑛 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 + න 𝑓 ∗ 𝑝𝑛 𝑡 ∗ 𝜕 𝜕𝑡 𝜔 𝑡 𝑑𝑥
  35. 35. Copyright © 2022 Morpho, Inc. All Rights Reserved 34 HiPPO (証明) まず係数𝑐𝑛(𝑡)の構成を思い出せば 𝑐𝑛 𝑡 = 𝑓≤𝑡, 𝑃𝑛 (𝑡) / 𝑃𝑛 (𝑡) , 𝑃𝑛 (𝑡) = න 𝑓 𝑥 ∗ 𝑝𝑛 𝑡 𝑥 ∗ 𝜔 𝑡 (𝑥)𝑑𝑥 両辺を𝑡で微分すると(積分と微分の可換性は認める) 𝑐𝑛 ′ (𝑡) = න 𝑓 ∗ 𝜕 𝜕𝑡 𝑝𝑛 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 + න 𝑓 ∗ 𝑝𝑛 𝑡 ∗ 𝜕 𝜕𝑡 𝜔 𝑡 𝑑𝑥 𝜕 𝜕𝑡 𝑝𝑛 𝑡 は𝑥についての𝑛次多項式 であるから𝑝0 (𝑡) , … , 𝑝𝑛 (𝑡) の線形和。
  36. 36. Copyright © 2022 Morpho, Inc. All Rights Reserved 35 HiPPO (証明) まず係数𝑐𝑛(𝑡)の構成を思い出せば 𝑐𝑛 𝑡 = 𝑓≤𝑡, 𝑃𝑛 (𝑡) / 𝑃𝑛 (𝑡) , 𝑃𝑛 (𝑡) = න 𝑓 𝑥 ∗ 𝑝𝑛 𝑡 𝑥 ∗ 𝜔 𝑡 (𝑥)𝑑𝑥 両辺を𝑡で微分すると(積分と微分の可換性は認める) 𝑐𝑛 ′ (𝑡) = න 𝑓 ∗ 𝜕 𝜕𝑡 𝑝𝑛 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 + න 𝑓 ∗ 𝑝𝑛 𝑡 ∗ 𝜕 𝜕𝑡 𝜔 𝑡 𝑑𝑥 𝜕 𝜕𝑡 𝑝𝑛 𝑡 は𝑥についての𝑛次多項式 であるから𝑝0 (𝑡) , … , 𝑝𝑛 (𝑡) の線形和。 𝜕 𝜕𝑡 𝜔(𝑡)は古典的な直交関数系では 𝜔(𝑡) とディラック𝛿𝑡の線形和。
  37. 37. Copyright © 2022 Morpho, Inc. All Rights Reserved 36 HiPPO (証明) まず係数𝑐𝑛(𝑡)の構成を思い出せば 𝑐𝑛 𝑡 = 𝑓≤𝑡, 𝑃𝑛 (𝑡) / 𝑃𝑛 (𝑡) , 𝑃𝑛 (𝑡) = න 𝑓 𝑥 ∗ 𝑝𝑛 𝑡 𝑥 ∗ 𝜔 𝑡 (𝑥)𝑑𝑥 両辺を𝑡で微分すると(積分と微分の可換性は認める) 𝑐𝑛 ′ (𝑡) = න 𝑓 ∗ 𝜕 𝜕𝑡 𝑝𝑛 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 + න 𝑓 ∗ 𝑝𝑛 𝑡 ∗ 𝜕 𝜕𝑡 𝜔 𝑡 𝑑𝑥 𝜕 𝜕𝑡 𝜔(𝑡)は古典的な直交関数系では 𝜔(𝑡) とディラック𝛿𝑡の線形和。 これより第一項は𝑐0, … , 𝑐𝑛の線形和、第二項は𝑐𝑛と𝑓(𝑡)の線形和。(証明終) 𝜕 𝜕𝑡 𝑝𝑛 𝑡 は𝑥についての𝑛次多項式 であるから𝑝0 (𝑡) , … , 𝑝𝑛 (𝑡) の線形和。
  38. 38. Copyright © 2022 Morpho, Inc. All Rights Reserved 37 HiPPO 実際にルジャンドル関数系を用いて実証してみる。 ただしルジャンドル関数系は[−1,1]上の関数系なので、 存在域を𝑡依存になるようスケールしてから適用する。 パターン1: [𝑡 − 𝜃, 𝑡]上に定義(𝜃 ≥ 0は何時刻前までを見るかを表すハイパラ) パターン2: [0, 𝑡]上に定義(過去の履歴をすべて見る) それぞれの場合で𝑐𝑛 ′ (𝑡)がどう書けるか見てみよう。 (図2つは[3]より引用)
  39. 39. Copyright © 2022 Morpho, Inc. All Rights Reserved 38 HiPPO パターン1: [𝑡 − 𝜃, 𝑡]上に定義 このときの正規直交関数系は、ルジャンドル関数系 𝑃𝑛 𝑥 を用いて 𝑝𝑛 𝑡 𝑥 ≔ 2𝑛 + 1 1 2𝑃𝑛 2 𝑥 − 𝑡 𝜃 + 1 𝜕 𝜕𝑡 𝑝𝑛 𝑡 = − 2𝑛 + 1 1 2 2 𝜃 2𝑛 − 1 1/2𝑝𝑛−1 (𝑡) + 2𝑛 − 5 1/2𝑝𝑛−3 (𝑡) + ⋯ またこのとき 𝜔 𝑡 𝑥 = 1 𝜃 1 𝑡−𝜃,𝑡 = 𝐻𝑒𝑣𝑖𝑠𝑖𝑑𝑒 𝑥 − 𝑡 − 𝜃 − 𝐻𝑒𝑣𝑖𝑠𝑖𝑑𝑒 𝑥 − 𝑡 𝜕 𝜕𝑡 𝜔 𝑡 = −𝛿 𝑥 − 𝑡 − 𝜃 + 𝛿 𝑥 − 𝑡 = 𝛿𝑡 − 𝛿𝑡−𝜃 (図は[3]より引用)
  40. 40. Copyright © 2022 Morpho, Inc. All Rights Reserved 39 HiPPO パターン1: [𝑡 − 𝜃, 𝑡]上に定義 これを先ほどの式に代入すると 第一項 = න 𝑓 ∗ 𝜕 𝜕𝑡 𝑝𝑛 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 = − 2𝑛 + 1 1 2 2 𝜃 2𝑛 − 1 1 2𝑐𝑛−1 + 2𝑛 − 5 1 2𝑐𝑛−3 + ⋯ 第二項 = න 𝑓 ∗ 𝑝𝑛 𝑡 ∗ 𝜕 𝜕𝑡 𝜔 𝑡 𝑑𝑥 = 𝑓 𝑡 𝑝𝑛 𝑡 𝑡 − 𝑓 𝑡 − 𝜃 𝑝𝑛 𝑡 𝑡 − 𝜃 辺々加えてこねこねすれば、 𝑐′ 𝑡 = − 1 𝜃 𝐴𝑐 𝑡 + 1 𝜃 𝐵𝑓 𝑡 𝐴𝑛𝑘 = 2𝑛 + 1 1 2 2𝑘 + 1 1 2 ൝ 1 𝑖𝑓 𝑘 ≤ 𝑛 −1 𝑛−𝑘 𝑖𝑓 𝑘 ≥ 𝑛 , 𝐵𝑛 = 2𝑛 + 1 1 2 (図は[3]より引用)
  41. 41. Copyright © 2022 Morpho, Inc. All Rights Reserved 40 HiPPO パターン1: [𝑡 − 𝜃, 𝑡]上に定義 Def. 測度 1 𝜃 1 𝑡−𝜃,𝑡 から導出されるHiPPOの時間発展式 𝑐′ 𝑡 = − 1 𝜃 𝐴𝑐 𝑡 + 1 𝜃 𝐵𝑓 𝑡 ただし 𝐴𝑛𝑘 = 2𝑛 + 1 1 2 2𝑘 + 1 1 2 ൝ 1 𝑖𝑓 𝑘 ≤ 𝑛 −1 𝑛−𝑘 𝑖𝑓 𝑘 ≥ 𝑛 , 𝐵𝑛 = 2𝑛 + 1 1 2 をHiPPO-LegTと呼ぶ。(translated Legendre) 実はこのODEは少し式変形すると[4]の論文で提案された式と一致する。 しかし[4]の論文ではPadé approximationという別手法を用いて導出。 (図は[3]より引用)
  42. 42. Copyright © 2022 Morpho, Inc. All Rights Reserved 41 HiPPO パターン2: [0, 𝑡]上に定義 このときの正規直交関数系は、ルジャンドル関数系 𝑃𝑛 𝑥 を用いて 𝑝𝑛 𝑡 𝑥 ≔ 2𝑛 + 1 1/2𝑃 𝑛 2𝑥 𝑡 − 1 𝜕 𝜕𝑡 𝑝𝑛 𝑡 = − 2𝑛 + 1 1/2 1 𝑡 𝑛 2𝑛 + 1 −1/2𝑝𝑛 (𝑡) + 2𝑛 − 1 1/2𝑝𝑛−1 (𝑡) + 2𝑛 − 3 1/2𝑝𝑛−2 (𝑡) + ⋯ またこのとき 𝜔 𝑡 𝑥 = 1 𝑡 1 0,𝑡 = 1 𝑡 𝐻𝑒𝑣𝑖𝑠𝑖𝑑𝑒 𝑥 − 𝐻𝑒𝑣𝑖𝑠𝑖𝑑𝑒 𝑥 − 𝑡 𝜕 𝜕𝑡 𝜔 𝑡 = − 1 𝑡2 1 0,𝑡 + 1 𝑡 𝛿 𝑥 − 𝑡 = 1 𝑡 (−𝜔(𝑡) + 𝛿𝑡) (図は[3]より引用)
  43. 43. Copyright © 2022 Morpho, Inc. All Rights Reserved 42 HiPPO パターン2: [0, 𝑡]上に定義 これを先ほどの式に代入すると 第一項 = න 𝑓 ∗ 𝜕 𝜕𝑡 𝑝𝑛 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 = − 2𝑛 + 1 1 2 1 𝑡 𝑛 2𝑛 + 1 − 1 2𝑐𝑛 + 2𝑛 − 1 1/2𝑐𝑛−1 + 2𝑛 − 3 1/2𝑐𝑛−2 + ⋯ 第二項 = න 𝑓 ∗ 𝑝𝑛 𝑡 ∗ 𝜕 𝜕𝑡 𝜔 𝑡 𝑑𝑥 = − 1 𝑡 𝑐𝑛 𝑡 + 𝑓 𝑡 𝑝𝑛 𝑡 (𝑡) 辺々加えてこねこねすれば、 𝑐′ 𝑡 = − 1 𝑡 𝐴𝑐 𝑡 + 1 𝑡 𝐵𝑓 𝑡 𝐴𝑛𝑘 = ൞ 2𝑛 + 1 1 2 2𝑘 + 1 1 2 (𝑛 > 𝑘) 𝑛 + 1 (𝑛 = 𝑘) 0 (𝑛 < 𝑘) , 𝐵𝑛 = 2𝑛 + 1 1 2 (図は[3]より引用)
  44. 44. Copyright © 2022 Morpho, Inc. All Rights Reserved 43 HiPPO パターン2: [0, 𝑡]上に定義 Def. 測度 1 𝑡 1 0,𝑡 から導出されるHiPPOの時間発展式 𝑐′ 𝑡 = − 1 𝑡 𝐴𝑐 𝑡 + 1 𝑡 𝐵𝑓 𝑡 ただし 𝐴𝑛𝑘 = ൞ 2𝑛 + 1 1 2 2𝑘 + 1 1 2 (𝑛 > 𝑘) 𝑛 + 1 (𝑛 = 𝑘) 0 (𝑛 < 𝑘) , 𝐵𝑛 = 2𝑛 + 1 1 2 をHiPPO-LegSと呼ぶ。(scaled Legendre) 実はこれがまさに本論文で提案する新手法に他ならない! (図は[3]より引用)
  45. 45. Copyright © 2022 Morpho, Inc. All Rights Reserved 44 HiPPO ラゲール、チェビシェフ、エルミート等、他の直交関数系に対しても導出が可能。 HiPPO-LegSは[0,t]のすべての時刻を見る点で直感的にHiPPO-LegTよりも優 れているが、以降でこの式が多くの嬉しい性質を満たすことを見る。 ここまでのまとめ • ℎ𝑖𝑝𝑝𝑜の出力(𝑐0, 𝑐1, … , 𝑐𝑁−1)の時間変化はlinear ODEで書ける。 • ODEの係数行列は陽に書けて実際に計算可能。 • HiPPOの枠組みで既存手法を導出可能(HiPPO-LegT)。 • HiPPO-LegSという新しい時間発展式を提案。
  46. 46. Copyright © 2022 Morpho, Inc. All Rights Reserved 45 HiPPO 最後にHiPPO-LegSの持つ良い性質を見ていこう。 スペースの都合上、ここでは時間スケールに依存しないことだけ見る。 その他の性質は最後に結果のみを列挙する。
  47. 47. Copyright © 2022 Morpho, Inc. All Rights Reserved 46 HiPPO (証明) 前述のODEを計算するにあたり、まずは離散化をしないといけない。 𝑐′ 𝑡 = − 1 𝑡 𝐴𝑐 𝑡 + 1 𝑡 𝐵𝑓 𝑡 の両辺を積分して、 𝑐 𝑡 + Δ𝑡 − 𝑐 𝑡 = න 𝑡 𝑡+Δ𝑡 − 1 𝑡 𝐴𝑐 𝑡 + 1 𝑡 𝐵𝑓 𝑡 𝑑𝑡 ≈ Δ𝑡 2 − 1 𝑡 𝐴𝑐 𝑡 + 1 𝑡 𝐵𝑓 𝑡 + − 1 𝑡 + Δ𝑡 𝐴𝑐 𝑡 + Δ𝑡 + 1 𝑡 + Δ𝑡 𝐵𝑓 𝑡 + Δ𝑡 Lemma ([3, Appendix B]) HiPPO-LegSは時間スケールに依存しない。
  48. 48. Copyright © 2022 Morpho, Inc. All Rights Reserved 47 HiPPO 辺々整理すると 𝐼 + Δ𝑡 2 𝑡 + Δ𝑡 𝐴 𝑐 𝑡 + Δ𝑡 = 𝐼 − Δ𝑡 2𝑡 𝐴 𝑐 𝑡 + Δ𝑡 2 𝑡 + Δ𝑡 + Δ𝑡 2𝑡 𝐵𝑓(𝑡) なお𝑓 𝑡 + Δ𝑡 = 𝑓(𝑡)の仮定を暗黙に使った。 ここで𝑡 = 𝑘Δ𝑡, 𝑐𝑘 ≔ 𝑐 𝑘Δ𝑡 , 𝑓𝑘 ≔ 𝑓(𝑘Δ𝑡)とすれば、 𝐼 + 1 2(𝑘 + 1) 𝐴 𝑐𝑘+1 = 𝐼 − 1 2𝑘 𝐴 𝑐𝑘 + 1 2 𝑘 + 1 + 1 2𝑘 𝐵𝑓(𝑡) ⇒どこにもΔ𝑡が出てこない!(証明終わり) (HiPPO-LegTなど他の直交関数系だとこうはならない)
  49. 49. Copyright © 2022 Morpho, Inc. All Rights Reserved 48 HiPPO 上の結果と合わせて、他の性質もまとめて理論説明を終わる。 ここまでのまとめ • HiPPO-LegSは時間スケールに依存しない。(ドメインシフトに強い) • HiPPOの1回のdiscretized ODE計算はO(N)。 • 𝑘 ∈ 𝑁: fixedおよび∀𝑙 > 𝑘に対して 𝜕𝑐𝑙+1 𝜕𝑓𝑘 = 𝑂(1/𝑙) (勾配消失・爆発しない!) • 𝑓𝑥≤𝑡の𝑆𝑝𝑎𝑛⟨𝑃0, … , 𝑃𝑁−1⟩への射影を𝑔(𝑡)としたとき • 𝑓が𝐿-Lipschitzなら 𝑓𝑥≤𝑡 − 𝑔(𝑡) = 𝑂 𝑡𝐿 𝑁 • 𝑓の𝑘回微分が有界なら 𝑓𝑥≤𝑡 − 𝑔(𝑡) = 𝑂 𝑡𝑘𝑁−𝑘+1/2
  50. 50. Copyright © 2022 Morpho, Inc. All Rights Reserved 49 HiPPO 実験 HiPPOの離散漸化式をRNNに組み込んで性能評価してみる。 hidden state ℎ𝑡の履歴を記憶させるよう下図のモデル設計を採用。 ([3]より引用)
  51. 51. Copyright © 2022 Morpho, Inc. All Rights Reserved 50 HiPPO 実験 タスク1: Permuted MNIST ([3]より引用)
  52. 52. Copyright © 2022 Morpho, Inc. All Rights Reserved 51 HiPPO 実験 タスク2: Character Trajectory Classification ペン先の3次元速度情報から書いている文字を当てるタスク。 サンプリングレートを変えてドメインシフトを再現しているが、 HiPPO-LegSは影響を受けていない。 ([3]より引用)
  53. 53. Copyright © 2022 Morpho, Inc. All Rights Reserved 52 HiPPO 実験 タスク3: Copying ([3]より引用)
  54. 54. Copyright © 2022 Morpho, Inc. All Rights Reserved 53 LSSL Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers (NeurIPS 2021)
  55. 55. Copyright © 2022 Morpho, Inc. All Rights Reserved 54 LSSL 概要 HiPPOを以下のように改良する。 • ቊ 𝑥′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) の形に増強。 • 𝐴, 𝐵, 𝐶, 𝐷を学習パラメータに変更。 • 上記連立方程式がCNN/RNNの要素を含むことを証明。
  56. 56. Copyright © 2022 Morpho, Inc. All Rights Reserved 55 LSSL 手法 動機は論文に書いてないが、HiPPOの式を ቊ 𝑥′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) の形に増強する。 (状態空間モデルの方程式を意識してると思われる) しれっとA,B,C,Dは時間依存しないことになってる? HiPPO-LegSは係数行列は時間依存してたが。。。。。。 t → ∞ でAはほぼ変化しないので定数とみなしてるのかも。
  57. 57. Copyright © 2022 Morpho, Inc. All Rights Reserved 56 LSSL 手法 𝑦 𝑡 が𝑥 𝑡 と𝑢(𝑡)の線形和であることに注目する。 ቊ 𝑥′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) これをLSSL(Linear State Space Layer)と呼ぶ。
  58. 58. Copyright © 2022 Morpho, Inc. All Rights Reserved 57 LSSL 手法 𝑦 𝑡 が𝑥 𝑡 と𝑢(𝑡)の線形和であることに注目する。 ቊ 𝑥′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) これをLSSL(Linear State Space Layer)と呼ぶ。 以下この方程式が次の性質を持つことを順にみていこう。 1.線形であるために、RNNより高速に計算可能。 2.線形だと貧弱な気がするが、実は十分な表現力を持つ。
  59. 59. Copyright © 2022 Morpho, Inc. All Rights Reserved 58 LSSL 1.高速に計算可能 まずLSSLをbilinear離散化すると、特に第1式について積分して 𝑥’ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢 𝑡 𝑥 𝑡 + Δ𝑡 − 𝑥 𝑡 = Δ𝑡 2 𝐴𝑥 𝑡 + 𝐵𝑢 𝑡 + 𝐴𝑥 𝑡 + Δ𝑡 + 𝐵𝑢 𝑡 + Δ𝑡 𝑥 𝑡 + Δ𝑡 = ҧ 𝐴𝑥 𝑡 + ത 𝐵𝑢 𝑡 ただし ҧ 𝐴 = 𝐼 − Δ 2 𝐴 −1 𝐼 + Δ 2 𝐴 , ത 𝐵 = 𝐼 − Δ 2 𝐴 −1 Δ𝐵
  60. 60. Copyright © 2022 Morpho, Inc. All Rights Reserved 59 LSSL 1.高速に計算可能 この離散化式 ൝ 𝑥𝑘 = ҧ 𝐴𝑥𝑘−1 + ത 𝐵𝑢𝑘 𝑦𝑘 = ҧ 𝐶𝑥𝑘 + ഥ 𝐷𝑢𝑘 から𝑥を削除すると、𝑥−1 = 0として 𝑦0 = ҧ 𝐶 ത 𝐵𝑢0 + ഥ 𝐷𝑢0 𝑦1 = ҧ 𝐶 ҧ 𝐴 ത 𝐵𝑢0 + ത 𝐵𝑢1 + ഥ 𝐷𝑢1 𝑦2 = ҧ 𝐶 ҧ 𝐴 ҧ 𝐴 ത 𝐵𝑢0 + ത 𝐵𝑢1 + ത 𝐵𝑢2 + ഥ 𝐷𝑢2 … … … 𝑦𝑘 = ҧ 𝐶 ҧ 𝐴 𝑘 ത 𝐵𝑢0 + ҧ 𝐶 ҧ 𝐴 𝑘−1 ത 𝐵𝑢1 + ⋯ + ҧ 𝐶 ത 𝐵𝑢𝑘 + ഥ 𝐷𝑢𝑘
  61. 61. Copyright © 2022 Morpho, Inc. All Rights Reserved 60 LSSL 1.高速に計算可能 ഥ 𝐷はお尻にしか付かないのでഥ 𝐷 = 0として無視しよう。すると 𝑦𝑘 = ҧ 𝐶 ҧ 𝐴 𝑘 ത 𝐵𝑢0 + ҧ 𝐶 ҧ 𝐴 𝑘−1 ത 𝐵𝑢1 + ⋯ + ҧ 𝐶 ത 𝐵𝑢𝑘 となり、この式はまさに𝑦 = 𝐾𝐿( ҧ 𝐴, ത 𝐵, ҧ 𝐶) ∗ 𝑢のconvolutionに他ならない。 𝐾𝐿 ҧ 𝐴, ത 𝐵, ҧ 𝐶 ≔ ( ҧ 𝐶 ത 𝐵, ҧ 𝐶 ҧ 𝐴 ത 𝐵, … , ҧ 𝐶 ҧ 𝐴𝐿−1 ത 𝐵) ここで𝐿はシーケンス長を表す。 これよりrecurrenceが不要になり、計算は高速。
  62. 62. Copyright © 2022 Morpho, Inc. All Rights Reserved 61 LSSL 2.十分な表現力を持つ Lemma ([5, Lemma 3.1]) LSSLはbackward-Eulerで離散化した場合、 RNNのgating mechanismを包含する。 (証明) LSSLの第一式 𝑥′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) をbackward-Eulerで離散化すると 𝑥 𝑡 + Δ𝑡 − 𝑥 𝑡 = න 𝑡 𝑡+Δ𝑡 𝐴𝑥 𝑡 + 𝐵𝑢 𝑡 𝑑𝑡 ≈ Δ𝑡 𝐴𝑥 𝑡 + Δ𝑡 + 𝐵𝑢(𝑡 + Δ𝑡)
  63. 63. Copyright © 2022 Morpho, Inc. All Rights Reserved 62 LSSL 2.十分な表現力を持つ Lemma ([5, Lemma 3.1]) LSSLはbackward-Eulerで離散化した場合、 RNNのgating mechanismを包含する。 𝑥𝑘 ≔ 𝑥 𝑡 , 𝑥𝑘+1 ≔ 𝑥 𝑡 + Δ𝑡 , 𝑢𝑘+1 ≔ 𝑢(𝑡 + Δ𝑡)とし、さらにΔ𝑡 = 𝑒𝑧とおけば、 𝑥𝑘+1 − 𝑥𝑘 ≈ 𝑒𝑧 𝐴𝑥𝑘+1 + 𝐵𝑢𝑘+1 𝑥𝑘+1 ≈ 1 − 𝐴𝑒𝑧 1 + 𝑒𝑧 𝑥𝑘 + 𝐵𝑒𝑧 1 + 𝑒𝑧 𝑢𝑘 ここで𝐴 = 𝐵 = 1とすれば、 𝑥𝑘+1 ≈ 1 − 𝜎 𝑧 𝑥𝑘 + 𝜎 𝑧 𝑢𝑘となり、 これはgating mechanismに他ならない。(証明終わり)
  64. 64. Copyright © 2022 Morpho, Inc. All Rights Reserved 63 LSSL 2.十分な表現力を持つ Lemma ([5, Lemma 3.2]) 𝑓(𝑡, 𝑥)がxについて局所Lipstizsである非線形関数としたとき、 無限にLSSLをstackしたモデルは𝑥’ 𝑡 = −𝑥 𝑡 + 𝑓(𝑡, 𝑥(𝑡))を解ける。 (証明概略) LSSLの線形部分をstackすると、それが実質ピカールの逐次近似 法を回していることになっている。 非線形部分𝑓はLSSL間にpointwise non-linearityな層を挟むことで再現す る。(証明終わり) ※この命題は本筋には使わない。詳細は各自論文参照。
  65. 65. Copyright © 2022 Morpho, Inc. All Rights Reserved 64 LSSL ここまでのまとめ • HiPPOにさらに線形方程式を追加したLSSLを提案。 • LSSLはconvolutionとして解釈可能なため高速。 • LSSLはRNNを含み、non-linear ODEを解くだけの能力を持つ。
  66. 66. Copyright © 2022 Morpho, Inc. All Rights Reserved 65 LSSL LSSLがHiPPOより真に優位であることは分かった。 次にこれを実際にどう学習に組み込むかを見ていく。 特に • Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい。 • Convolution:𝐾𝐿 ҧ 𝐴, ത 𝐵, ҧ 𝐶 ≔ ( ҧ 𝐶 ത 𝐵, ҧ 𝐶 ҧ 𝐴 ത 𝐵, … , ҧ 𝐶 ҧ 𝐴𝐿−1 ത 𝐵)を如何に高速計算するか。 を調べたい。
  67. 67. Copyright © 2022 Morpho, Inc. All Rights Reserved 66 LSSL Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい AはHiPPOで導出されるような行列のクラスに限定したい。 一体それはどんな形で書けるのだろうか? Convolution:𝑲𝑳 ഥ 𝑨, ഥ 𝑩, ഥ 𝑪 ≔ (ഥ 𝑪ഥ 𝑩, ഥ 𝑪ഥ 𝑨ഥ 𝑩, … , ഥ 𝑪ഥ 𝑨𝑳−𝟏 ഥ 𝑩)を如何に高速計算するか この式の中にはAのべき乗が大量に入っているので、愚直計算で𝑂(𝑁3 𝐿)かかる。 もっと速く計算できないだろうか?
  68. 68. Copyright © 2022 Morpho, Inc. All Rights Reserved 67 LSSL Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい AはHiPPOで導出されるような行列のクラスに限定したい。 一体それはどんな形で書けるのだろうか? Convolution:𝑲𝑳 𝑨, 𝑩, 𝑪 ≔ (𝑪𝑩, 𝑪𝑨𝑩, … , 𝑪𝑨𝑳−𝟏𝑩)を如何に高速計算するか この式の中にはAのべき乗が大量に入っているので、愚直計算で𝑂(𝑁3 𝐿)かかる。 もっと速く計算できないだろうか? ここで残念なお知らせ LSSLの論文でこの考察をしているが、 その結果はお世辞にもきれいとは言えない。 しかも計算は非常に不安定。
  69. 69. Copyright © 2022 Morpho, Inc. All Rights Reserved 68 LSSL Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい AはHiPPOで導出されるような行列のクラスに限定したい。 一体それはどんな形で書けるのだろうか? Convolution:𝑲𝑳 𝑨, 𝑩, 𝑪 ≔ (𝑪𝑩, 𝑪𝑨𝑩, … , 𝑪𝑨𝑳−𝟏𝑩)を如何に高速計算するか この式の中にはAのべき乗が大量に入っているので、愚直計算で𝑂(𝑁3 𝐿)かかる。 もっと速く計算できないだろうか? これらの問題点は S4の論文にて 1年越しに解決!
  70. 70. Copyright © 2022 Morpho, Inc. All Rights Reserved 69 S4 Efficiently Modeling Long Sequences with Structured State Spaces (ICLR 2022 Oral)
  71. 71. Copyright © 2022 Morpho, Inc. All Rights Reserved 70 S4 概要 LSSLで消化不良だった • Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい。 • Convolution:𝐾𝐿 ҧ 𝐴, ത 𝐵, ҧ 𝐶 ≔ ( ҧ 𝐶 ത 𝐵, ҧ 𝐶 ҧ 𝐴 ത 𝐵, … , ҧ 𝐶 ҧ 𝐴𝐿−1 ത 𝐵)を如何に高速計算するか。 を解決する。
  72. 72. Copyright © 2022 Morpho, Inc. All Rights Reserved 71 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい 一般的なHiPPO行列の形を導出するのは難しい。(LSSLの論文ではそれをやって大変汚いことに) そこで 「計算しやすさ」 と 「HiPPO-LegT/LegSを含む」 ことを条件に、学習する行列Aのクラスを決める。
  73. 73. Copyright © 2022 Morpho, Inc. All Rights Reserved 72 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Def. 行列𝐴 ∈ 𝑅𝑛∗𝑛が 𝐴 = 𝐹 − 𝑝𝑞𝑇 (𝐹: 𝑛𝑜𝑟𝑚𝑎𝑙, 𝑝, 𝑞 ∈ ℝ𝑛∗𝑘 𝑘 ≪ 𝑛 ) と書けるとき、𝐴はNPLR(Normal Plus Low-Rank)表現を持つという。 (Plusと言いつつマイナスにしているのは、本スライドでの説明の都合による) Fact 以下は同値: 1. 𝐹はnormal (i.e. 𝐹𝐹∗ = 𝐹∗𝐹) 2. 𝐹はユニタリ行列で対角化可能
  74. 74. Copyright © 2022 Morpho, Inc. All Rights Reserved 73 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Theorem 1]) HiPPO LegT/LegSはNPLR表現を持つ。 (証明) 以下ではHiPPO LegSのみ見ていく。このとき行列𝐴は 𝐴𝑛𝑘 = − ൞ 2𝑛 + 1 1 2 2𝑘 + 1 1 2 (𝑛 > 𝑘) 𝑛 + 1 (𝑛 = 𝑘) 0 (𝑛 < 𝑘) と書けた。
  75. 75. Copyright © 2022 Morpho, Inc. All Rights Reserved 74 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Theorem 1]) HiPPO LegT/LegSはNPLR表現を持つ。 ここで𝑝 = 2𝑛+1 2 1 2 𝑛 とすると、 𝑝𝑝𝑇 𝑛𝑘 = 1 2 2𝑛 + 1 1 2 2𝑘 + 1 1 2 であり、 𝐴 + 𝑝𝑝𝑇 𝑛𝑘 = − 1 2 2𝑛 + 1 1 2 2𝑘 + 1 1 2 (𝑛 > 𝑘) ∗∗∗ 略 ∗∗∗ (𝑛 = 𝑘) − 1 2 2𝑛 + 1 1 2 2𝑘 + 1 1 2(𝑛 < 𝑘)
  76. 76. Copyright © 2022 Morpho, Inc. All Rights Reserved 75 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Theorem 1]) HiPPO LegT/LegSはNPLR表現を持つ。 すなわち 𝐴 + 𝑝𝑝𝑇 = 𝑠𝑘𝑒𝑤_𝑠𝑦𝑚𝑚𝑒𝑡𝑟𝑖𝑐 + 𝑘𝐼, ∃𝑘 ∈ ℝ の形になっており、特に右辺は正規行列。(証明終わり)
  77. 77. Copyright © 2022 Morpho, Inc. All Rights Reserved 76 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい これを踏まえて行列AはNPLRの中で学習させることを考える。 が、実はさらにクラスを制限しても問題ないことを次に示す。 Def. 行列𝐴 ∈ 𝑅𝑛∗𝑛が 𝐴 = Λ − 𝑝𝑞𝑇 (Λ: 𝑑𝑖𝑎𝑔𝑜𝑛𝑎𝑙, 𝑝, 𝑞 ∈ ℝ𝑛∗𝑘 𝑘 ≪ 𝑛 ) と書けるとき、𝐴はDPLR(Diagonal Plus Low-Rank)表現を持つという。
  78. 78. Copyright © 2022 Morpho, Inc. All Rights Reserved 77 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Lemma 3.1]) HiPPO行列に共役な作用を施しても出力は不変。 (証明) 主張がやや不明瞭だが、証明を見れば意味が分かる。 ቊ 𝑥′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) に対して 𝐴, 𝐵, 𝐶, 𝐷 → (𝑉−1𝐴𝑉, 𝑉−1𝐵, 𝐶𝑉, 𝐷)の変換を施すと、 ൝ 𝑥′ 𝑡 = 𝑉−1𝐴𝑉𝑥 𝑡 + 𝑉−1𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑉𝑥 𝑡 + 𝐷𝑢(𝑡) ↔ ቊ 𝑉𝑥′ 𝑡 = 𝐴𝑉𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑉𝑥 𝑡 + 𝐷𝑢(𝑡)
  79. 79. Copyright © 2022 Morpho, Inc. All Rights Reserved 78 S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Lemma 3.1]) HiPPO行列に共役な作用を施しても出力は不変。 すなわち𝑉の共役な作用が𝐴にかかっても、𝐵, 𝐶を適切に変換すれば、 作用の影響は潜在変数の変数変換にとどまる。(証明終わり) これより行列𝐴をDPLRの中で学習させるとしても問題ない。
  80. 80. Copyright © 2022 Morpho, Inc. All Rights Reserved 79 S4 ここまでのまとめ 𝐴 = Λ − 𝑝𝑞𝑇 として、Λ, 𝑝, 𝑞を学習させることにする。 これにより求まる𝐴の属する空間は、 古典的な直交関数形に対するHiPPO行列たちを含む。
  81. 81. Copyright © 2022 Morpho, Inc. All Rights Reserved 80 S4 𝐾𝐿 ҧ 𝐴, ത 𝐵, ҧ 𝐶 ≔ ( ҧ 𝐶 ത 𝐵, ҧ 𝐶 ҧ 𝐴 ത 𝐵, … , ҧ 𝐶 ҧ 𝐴𝐿−1 ത 𝐵)の高速計算 ここが本論文の山場。 なんと上記のconvolutionカーネル計算を、 愚直計算の𝑂(𝑁3𝐿)からなんと ෨ 𝑂(𝑁 + 𝐿)にまで落としてしまう。 超絶技巧が盛りだくさんなので、step-by-stepに追っていこう。
  82. 82. Copyright © 2022 Morpho, Inc. All Rights Reserved 81 S4 STEP0. 先ほど述べたように、 𝐴 = Λ − 𝑝𝑞𝑇 (Λ: 𝑑𝑖𝑎𝑔𝑜𝑛𝑎𝑙, 𝑝, 𝑞 ∈ ℝ𝑛∗1) と置く。説明簡単化のため、𝑝, 𝑞は𝑛 ∗ 1行列とする。(HiPPO-LegSはそう) また統一性のため ቊ 𝑥′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑦 𝑡 = 𝐶∗𝑥 𝑡 + 𝐷𝑢(𝑡) のように𝐶を転置して、𝐶が𝐵, 𝑝, 𝑞と同じℝ𝑛∗1の元であるようにする。
  83. 83. Copyright © 2022 Morpho, Inc. All Rights Reserved 82 S4 STEP1. 𝐾𝐿 ҧ 𝐴, ത 𝐵, ҧ 𝐶∗ ≔ ( ҧ 𝐶∗ ത 𝐵, ҧ 𝐶∗ ҧ 𝐴 ത 𝐵, … , ҧ 𝐶∗ ҧ 𝐴𝐿−1 ത 𝐵) を直接求めるのではなく、それのz変換もどき ෡ 𝐾𝐿 𝑧; ҧ 𝐴, ത 𝐵, ҧ 𝐶∗ ≔ ෍ 𝑖=0 𝐿−1 ҧ 𝐶∗ ҧ 𝐴𝑖 ത 𝐵𝑧𝑖 ∈ ℂ[𝑧] を求めることを考える。 ෡ 𝐾𝐿から𝐾𝐿を導出するのは、zに1のべき根を突っ込んでiFFTにより𝑂(𝐿 log 𝐿)
  84. 84. Copyright © 2022 Morpho, Inc. All Rights Reserved 83 S4 STEP2. Lemma ([6, Lemma C.3]) ෡ 𝐾𝐿 𝑧; ҧ 𝐴, ത 𝐵, ҧ 𝐶∗ = 2 1 + 𝑧 ሚ 𝐶∗𝑅 𝑧 𝐵 − ሚ 𝐶∗𝑅 𝑧 𝑝 1 + 𝑞∗𝑅 𝑧 𝑝 −1𝑞∗𝑅 𝑧 𝐵 ただし ሚ 𝐶 = 𝐶 𝐼 − ҧ 𝐴𝐿 , 𝑅 𝑧; Λ = 2 Δ 1 − 𝑧 1 + 𝑧 − Λ −1 (証明) 形式べき級数を用いて、mod 𝑧𝐿 の下で ෡ 𝐾𝐿 𝑧; ҧ 𝐴, ത 𝐵, ҧ 𝐶∗ ≔ ෍ 𝑖=0 𝐿−1 ҧ 𝐶∗ ҧ 𝐴𝑖 ത 𝐵𝑧𝑖 = ҧ 𝐶∗ 𝐼 − ҧ 𝐴𝐿 𝐼 − ҧ 𝐴𝑧 −1 ത 𝐵 = ሚ 𝐶∗ 𝐼 − ҧ 𝐴𝑧 −1 ത 𝐵
  85. 85. Copyright © 2022 Morpho, Inc. All Rights Reserved 84 S4 また、LSSLの離散化手続きを思い出すと ҧ 𝐴 = 𝐼 − Δ 2 𝐴 −1 𝐼 + Δ 2 𝐴 , ത 𝐵 = 𝐼 − Δ 2 𝐴 −1 Δ𝐵 であるが、これを前述の式に代入すると以下を得る(詳細は[6, Lemma C.4]) ሚ 𝐶∗ 𝐼 − ҧ 𝐴𝑧 −1 ത 𝐵 = 2Δ 1 + 𝑧 ሚ 𝐶∗ 2 1 − 𝑧 1 + 𝑧 𝐼 − Δ𝐴 −1 𝐵
  86. 86. Copyright © 2022 Morpho, Inc. All Rights Reserved 85 S4 ここでさらに𝐴 = Λ − 𝑝𝑞𝑇なことを思い出すと ሚ 𝐶∗ 𝐼 − ҧ 𝐴𝑧 −1 ത 𝐵 = 2Δ 1 + 𝑧 ሚ 𝐶∗ 2 1 − 𝑧 1 + 𝑧 𝐼 − Δ Λ − 𝑝𝑞∗ −1 𝐵 = 2 1 + 𝑧 ሚ 𝐶∗ 2 Δ 1 − 𝑧 1 + 𝑧 𝐼 − Λ + 𝑝𝑞∗ −1 𝐵 = 2 1 + 𝑧 ሚ 𝐶∗𝑅 𝑧 𝐵 − ሚ 𝐶∗𝑅 𝑧 𝑝 1 + 𝑞∗𝑅 𝑧 𝑝 −1 𝑞∗𝑅 𝑧 𝐵 なお最後の等号はWoodbury Identityから従う。(証明終わり) Fact (Woodbury Identity) 任意の行列𝐴, 𝑃, 𝑄に対して以下が成り立つ 𝐴 + 𝑈𝑉∗ −1 = 𝐴−1 − 𝐴−1𝑈 𝐼 + 𝑉∗𝐴−1𝑈 −1𝑉∗𝐴−1 diagonal
  87. 87. Copyright © 2022 Morpho, Inc. All Rights Reserved 86 S4 STEP2. 求めた式は一見煩雑になっただけに見えるが、よく見ると赤線部分 ෡ 𝐾𝐿 𝑧; ҧ 𝐴, ത 𝐵, ҧ 𝐶∗ = 2 1 + 𝑧 ሚ 𝐶∗ 𝑅 𝑧 𝐵 − ሚ 𝐶∗ 𝑅 𝑧 𝑝 1 + 𝑞∗ 𝑅 𝑧 𝑝 −1 𝑞∗ 𝑅 𝑧 𝐵 はすべてスカラーであり、かつ𝑅 𝑧; Λ = 2 Δ 1−𝑧 1+𝑧 − Λ −1 は対角行列。 すなわち上の計算は登場する行列たちが既知なら𝑂(𝑁)で求まる。
  88. 88. Copyright © 2022 Morpho, Inc. All Rights Reserved 87 S4 STEP3. よってあとは新規の登場人物たち、とくに ሚ 𝐶 = 𝐶 𝐼 − ҧ 𝐴𝐿 , 𝑅 𝑧; Λ = 2 Δ 1 − 𝑧 1 + 𝑧 − Λ −1 の2つが高速に求められれば良い。 前者は発想の転換で、𝐶ではなく ሚ 𝐶を最初から学習させることにすれば解決。
  89. 89. Copyright © 2022 Morpho, Inc. All Rights Reserved 88 S4 STEP3. 𝑅 𝑧; Λ = 2 Δ 1 − 𝑧 1 + 𝑧 − Λ −1 だが、一見対角行列なので𝑂(𝑁)で計算可能で、何も問題ないように見える。 しかしSTEP1を見直すと、我々は ෡ 𝐾𝐿 𝑧; ҧ 𝐴, ത 𝐵, ҧ 𝐶∗ をすべての1の𝐿乗根に対して求める必要がある。 よってこのままだと𝑂(𝑁𝐿)かかってしまいまずい。
  90. 90. Copyright © 2022 Morpho, Inc. All Rights Reserved 89 S4 STEP3. どのみち𝑅(𝑧)を𝑂(𝑁)より早く求めたとしても、STEP2の計算をすべての1の𝐿 乗根に対して行うと𝑂(𝑁𝐿)かかってしまう。 そこで少し視点を変えて、一般に赤線部分 𝑉∗𝑅 𝑧 𝑈, (∀𝑈, 𝑉 ∈ ℝ𝑛∗1) をすべての𝑧 ∈ {1の𝐿乗根}に対して一括で ෨ 𝑂(𝑁 + 𝐿)で求めることを考える。 実は𝑅(𝑧)の特殊構造により、これが可能である。
  91. 91. Copyright © 2022 Morpho, Inc. All Rights Reserved 90 S4 STEP3. Def. K ∈ ℝ𝑀∗𝑁であって、 𝐾𝑖𝑗 = 1 𝜔𝑖 − 𝜆𝑗 , (𝜔𝑖, 𝜆𝑗 ∈ ℂ) と書けるものをCauchy Kernelと呼ぶ。 Fact [7] Cauchy Kernelの行列ベクトル積にかかる計算量は ൞ 𝑂 𝑀 + 𝑁 log2 𝑀 + 𝑁 , 𝑒𝑥𝑎𝑐𝑡 𝑎𝑟𝑖𝑡ℎ𝑚𝑒𝑡𝑖𝑐 𝑂 𝑀 + 𝑁 log 𝑀 + 𝑁 log 1 𝜖 , 𝑛𝑢𝑚𝑒𝑟𝑖𝑐𝑎𝑙𝑙𝑦 𝑡𝑜 𝑝𝑟𝑒𝑐𝑖𝑠𝑖𝑜𝑛 𝜖
  92. 92. Copyright © 2022 Morpho, Inc. All Rights Reserved 91 S4 STEP3. これを踏まえると 𝑅 𝑧; Λ = 2 Δ 1 − 𝑧 1 + 𝑧 − Λ −1 はまさにCauchy Kernelに他ならない。 ゆえにすべての𝑧 ∈ {1の𝐿乗根}に対して、 ෡ 𝐾𝐿 𝑧; ҧ 𝐴, ത 𝐵, ҧ 𝐶∗ の計算は一括で ෨ 𝑂(𝑁 + 𝐿)で終わる。 STEP1のiFFTは𝑂(𝐿 log 𝐿)なので、全体としても ෨ 𝑂(𝑁 + 𝐿)で計算が完了する。 (おわり!)
  93. 93. 92 Copyright © 2022 Morpho, Inc. All Rights Reserved. • 時系列モデリングの新手法HiPPOを提案。 • HiPPOを状態空間モデルの方程式に組み込み、高速なconvolution計算を実現。 • Path-Xタスクで世界初の推論成功を達成。 「所感」 • 概念基盤がかなりしっかりしていて、かつ汎用性が高い。 • 後続研究にS4をaudio generationやvideo classificationに使用した例あり。 「おまけ」 公式実装:https://github.com/HazyResearch/state-spaces 解説付きJax実装:https://srush.github.io/annotated-s4 まとめ まとめ
  94. 94. 93 Copyright © 2022 Morpho, Inc. All Rights Reserved. [1] Yi Tay, Mostafa Dehghani, Samira Abnar, Yikang Shen, Dara Bahri, Philip Pham, Jinfeng Rao, Liu Yang, Sebastian Ruder, and Donald Metzler. Long range arena : A benchmark for efficient transformers. In International Conference on Learning Representations, 2021. [2] 黒田成俊. 関数解析. 共立出版. 1980. [3] Albert Gu, Tri Dao, Stefano Ermon, Atri Rudra, and Christopher R´e. Hippo: Recurrent memory with optimal polynomial projections. In Advances in Neural Information Processing Systems, pages 1474-1487, 2020. [4] Aaron Voelker, Ivana Kajić, and Chris Eliasmith. Legendre memory units: Continuous- time representation in recurrent neural networks. In Advances in Neural Information Processing Systems, pages 15544–15553, 2019. [5] Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, and Christopher R´e. Combining recurrent, convolutional, and continuous-time models with the structured learnable linear state space layer. In Advances in Neural Information Processing Systems, pages 572-585, 2021. まとめ 参考文献
  95. 95. 94 Copyright © 2022 Morpho, Inc. All Rights Reserved. [6] Albert Gu, Karan Goel, and Christopher R´e. Efficiently modeling long sequences with structured state spaces. In International Conference on Learning Representations, 2022. [7] Victor Pan. Structured matrices and polynomials: unified superfast algorithms. Springer Science & Business Media, 2001. まとめ 参考文献
  96. 96. Thank you

More Related Content

Related Books

Free with a 30 day trial from Scribd

See all

×