Successfully reported this slideshow.
Your SlideShare is downloading. ×

【DL輪読会】Transformers are Sample Efficient World Models

Ad
Ad
Ad
Ad
Ad
Ad
Ad
Ad
Ad
Ad
Ad

Check these out next

1 of 21 Ad
Advertisement

More Related Content

More from Deep Learning JP (20)

Recently uploaded (20)

Advertisement

【DL輪読会】Transformers are Sample Efficient World Models

  1. 1. 1 DEEP LEARNING JP [DL Papers] http://deeplearning.jp/ DL輪読会:Transformers are Sample Efficient World Models Ryoichi Takase
  2. 2. 書誌情報 2 ※注釈無しの図は本論文から抜粋 採録:ICLR2023 under review 概要: Discrete autoencoderとTransformerを組み合わせた世界モデルを提案 モデルベース強化学習を用いてAtari100kベンチマークで高性能を発揮
  3. 3. 背景 3 世界モデル [1] モデルベース強化学習であり、世界モデル内(想像の中)で方策を学習 → 性能向上に十分な回数を試行可能なためサンプル効率が良い 強化学習の課題: 高性能を発揮するが、学習には非常に多くの経験データを必要とする → サンプル効率が悪い [1] Ha, David, and Jürgen Schmidhuber. "World models." 2018. 観測データを直接扱うとタスクと無関係な情報の変化で性能が劣化する 例)ゲーム画面の背景画像など → 汎化性能が低い 潜在変数空間における状態遷移のモデル化 → タスクの本質を学習することで汎化性能が向上 想像の中で学習するため世界モデルの精度が性能に直結する
  4. 4. 研究目的 4 関連研究:DreamerV2 [2] 世界モデルベースの強化学習アルゴリズム Atari環境でRainbowを上回る性能を発揮 関連研究:Decision Transformer [3] Transformerモデルが自然言語処理の枠組みを超えて強化学習で高性能を発揮 研究目的: Transformerの系列モデリング技術を応用して高精度な世界モデルを構築 得られた世界モデルを用いて高性能なモデルベース強化学習を実現 [3] Chen, Lili, et al. "Decision transformer: Reinforcement learning via sequence modeling." 2021. [2] Hafner, Danijar, et al. "Mastering atari with discrete world models." 2020.
  5. 5. 提案する世界モデルの概要 5 ①エンコーダ𝐸が初期フレーム𝑥0をトークン𝑧0に変換(実際の環境情報で初期化) ②デコーダ𝐷がトークン𝑧𝑡を画像𝑥𝑡 に再構成 ③方策𝜋が再構成画像𝑥𝑡から行動𝑎𝑡をサンプリング 次の状態𝑥𝑡+1、報酬𝑟𝑡、エピソードの終了𝑑𝑡を予測 ④Transformerが報酬𝑟、エピソードの終了𝑑、次のトークン𝑧𝑡+1を予測 ① ② ③ ④ 提案手法: IRIS (Imagination with auto-Regression over an Inner Speech) Discrete autoencoderとTransformerを組み合わせて世界モデルを構築
  6. 6. ① ② ③ ④ Discrete autoencoder 6 ① エンコーダ 𝑬: 入力画像𝑥𝑡をvocab size 𝑁 のトークンに変換 ② デコーダ 𝑫: CNNデコーダを用いてトークンを画像𝑥に再構成 Discrete autoencoderの学習: 収集したフレームデータを使用 損失関数としてL2 reconstruction、commitment、perceptualを等しく重みづけ Convolutional Neural Network (CNN)により入力画像𝑥𝑡を出力𝑦𝑡に変換 トークン𝑧𝑡を𝑧𝑡 𝑘 = argmin𝑖 𝑦𝑡 𝑘 − 𝑒𝑖 で選択 (ℰ = 𝑒𝑖 𝑖=1 𝑁 :対応する埋め込み表)
  7. 7. Transformer 7 Transformerの学習: 損失関数としてTransitionとTerminationには交差エントロピー誤差、 Rewardには交差エントロピー誤差もしくは平均二乗誤差を使用 ④ Transformer 𝑮: Discrete autoencoderで得たトークンを用いて、潜在空間での状態遷移モデルを学習 時刻𝑡までのトークン𝑧≤𝑡と行動𝑎≤𝑡に加えて 時刻𝑡 + 1で既に予測した も使用して予測 ① ② ③ ④
  8. 8. 学習手順 8 (B) 世界モデルの学習: 1. 学習データを𝒟からサンプリング 2. Discrete autoencoderを更新 3. Transformerを更新 (C) 方策の学習: 1. 初期フレームを𝒟からサンプリング 2. 世界モデル内で経験データを収集 3. 方策・価値関数を更新 (B) → (C) → 環境との相互作用 ※目的関数とハイパーパラメータはDreamerV2を参考に設定 学習ループ→ 世界モデルの更新 方策の更新 (A) → (A) 環境との相互作用: 実環境で軌跡データを収集して𝒟に格納
  9. 9. ベンチマーク環境 9 Atari100kベンチマーク: 26種類のAtari ゲームで構成 エージェントは各環境で100kステップの行動が可能 → 人間のゲームプレイ約2時間に相当する ゲーム例:Frostbite (左) と Krull (右)
  10. 10. ベースラインアルゴリズム 10 先読み検索の有無でベースラインを区別: IRIS(提案手法)はMonte Carlo Tree Searchとの組み合わせが可能だが、 本論文では先読み検索なしの手法を比較対象として設定 先読み検索なし: SimPLe [5]、CURL [6]、DrQ [7]、SPR [8] 先読み検索あり: MuZero [9]、EfficientZero [10] [5] Kaiser, Łukasz, et al. "Model Based Reinforcement Learning for Atari." 2019. [6] Srinivas, Aravind, Michael Laskin, and Pieter Abbeel. "CURL: Contrastive Unsupervised Representations for Reinforcement Learning." 2020. [7] Yarats, Denis, Ilya Kostrikov, and Rob Fergus. "Image augmentation is all you need: Regularizing deep reinforcement learning from pixels." 2020. [8] Schwarzer, Max, et al. "Data-efficient reinforcement learning with self-predictive representations." 2020. [9] Schrittwieser, Julian, et al. "Mastering atari, go, chess and shogi by planning with a learned model." 2020. [10] Ye, Weirui, et al. "Mastering atari games with limited data." 2021.
  11. 11. 数値実験の評価方法 11 層別ブーストラップによる信頼区間の推定: 平均値(Mean)と中央値(Median)に加えて、 下位25%と上位25%を除いた残りの50%の平均値(Interquartile mean: IQC)の信頼区間を推定 ℎ𝑢𝑚𝑎𝑛 𝑛𝑜𝑟𝑚𝑎𝑙𝑖𝑧𝑒𝑑 𝑠𝑐𝑜𝑟𝑒 = 𝑠𝑐𝑜𝑟𝑒𝑎𝑔𝑒𝑛𝑡 − 𝑠𝑐𝑜𝑟𝑒𝑟𝑎𝑛𝑑𝑜𝑚 𝑠𝑐𝑜𝑟𝑒ℎ𝑢𝑚𝑎𝑛 − 𝑠𝑐𝑜𝑟𝑒𝑟𝑎𝑛𝑑𝑜𝑚 正規化スコアの定義: 𝑠𝑐𝑜𝑟𝑒𝑟𝑎𝑛𝑑𝑜𝑚 𝑠𝑐𝑜𝑟𝑒ℎ𝑢𝑚𝑎𝑛 文献[11]に従い正規化スコアを用いて評価を実施 [11] Agarwal, Rishabh, et al. "Deep reinforcement learning at the edge of the statistical precipice." 2021. Performance profileの図示: 正規化スコア以上の割合をグラフ化 :ランダム方策のスコア :人間プレイヤーのスコア
  12. 12. 信頼区間に関する結果 12 IRIS(提案手法)は平均値1.046、IQM値0.501を達成 → 26ゲーム中10ゲームで人間のプレイヤーより高い性能を発揮
  13. 13. Performance Profileに関する結果 13 IRIS(提案手法)はベースラインと同等以上の性能 正規化スコアを超える割合が0.5以下の場合は他手法よりも高性能 → Atari100kベンチマークで先読み検索を使用しない最先端技術であることを示唆 グラフの見方: 縦軸:正規化スコア以上の割合 横軸:正規化スコア 上にある曲線ほど優れた手法であることを意味 低性能 高性能 スコアが0以上の割合が100% スコアが1以上の割合が約30%(IRISが最も高性能)
  14. 14. 実験結果 14 Pong、Breakout、Boxingのような分布シフトの影響が小さいゲームで特に高性能を発揮
  15. 15. 実験結果 15 FrostbiteとKrullのようなサブゲームを段階的にクリアするゲームでは性能を発揮できない場合がある
  16. 16. FrostbiteとKrullの結果の考察 16 Frostbiteで低性能となった考察: 最初のレベルを終了するには、イグルー構築後に画面下部からイグルーに戻るという 稀でかつ一連の長い行動が必要 → 稀な事象は想像上で十分に経験できないため性能が低くなる Frostbite (左) と Krull (右)の3 つの連続レベル Krullで高性能となった考察: 次のステージへの移行が頻繁に行われる → 世界モデルがゲームの多様性をうまく反映できたため想像上でも十分に経験できた
  17. 17. 世界モデルの性能解析 17 想像の中で方策を学習するため世界モデルの精度が性能に直結する → 世界モデルの精度を生成画像から確認 Discrete autoencoder: ボール、プレイヤー、敵などの要素を正しく再構成しているか? Transformer: ゲームの重要な仕組み(報酬やエピソード終了)を正しく捉えているか? 性能評価のポイント: IRIS(提案手法)の世界モデルの性能解析を以下のケースで実施 KungFuMaster、Pong、BreakoutとGopher
  18. 18. KungFuMasterでの性能解析 18 各シミュレーションで様々な状況(敵の数など)を生成 青枠からプレイヤーに攻撃された敵は姿を消していることが確認できる → 世界モデルはゲームの重要な仕組みを捉えている 4つの軌跡例 シミュレーション開始点 (実環境の情報で初期化) 世界モデルの想像結果
  19. 19. Pongでの性能解析 19 世界モデルはボールの軌道と選手の動きを捉えている 青枠から勝者側のスコアボードが更新されていることが確認できる → ピクセル単位で高精度な予測を実現 世界モデルの 生成結果 → 実際の結果 → シミュレーション開始点 (実環境の情報で初期化)
  20. 20. BreakoutとGopherでの性能解析 20 黄枠:世界モデルが正の報酬を予測するフレーム 赤枠:エピソード終了のを予測しているフレーム 各行は実環境の情報で初期化し、残りの軌道を想像させた結果 ゲームの仕組みを高精度に予測 Gopher: 黄枠:穴をふさぐかモグラを倒すと報酬につながる 赤枠:モグラが人参に到達するとエピソードが終了 Breakout: 黄枠:レンガを壊すと報酬が得らる 赤枠:ボールを逃すとエピソードが終了
  21. 21. まとめ 21 IRIS (Imagination with auto-Regression over an Inner Speech): Discrete autoencoderとTransformerを組み合わせた世界モデルを提案 実験結果: Atari100kベンチマークで高性能を発揮 世界モデルはゲームの重要な仕組みを捉えて高精度な予測を実現 → 先読み検索を使用しない手法として最先端技術であることを示唆

×