TabNet: Attentive Interpretable
Tabular Learning
ニューラルネットワークと決定木のいいとこ取り?
機械学習コンペでよく使われる手法
今回紹介する内容
● TabNet: Attentive Interpretable Tabular Learning
○ 著者: Sercan O. Arik, Tomas Pfister
○ 投稿: https://arxiv.org/abs/1908.07442, 2019
● 論文の内容
○ 表形式のデータに適したDNNモデルTabNetを提案
○ DNNと決定木系の利点を組み合わせたモデル
○ 既存のデータセットに対してXGBoost, LightGBM, MLPなどと同等以上の性能
○ 特徴生成が不要であり、特徴重要度も算出できる
○ 教師なしの事前学習と組み合わせることでさらに精度が改善する
表形式データへのDNNの適用
● DNNはこれまで画像や自然言語、音声などの分野で多くの成果を残して
きた
● 実社会で最も馴染みのあるデータは表形式であるのにも関わらずこのよ
うな表形式データ向けのDNNは発展してこなかった
● 決定木系のアルゴリズムが中心に
● なぜか?
○ 解釈性が高い ← 実社会で求められる
○ 学習が速い
○ 勾配Boosting木の登場によりモデルの汎化性能が向上した
○ DNNの構造が表形式に適していない (スタック型の畳み込み層など)
DNNを用いる利点
● 大きなデータでの精度向上が期待できる
● end-to-endな学習が可能
○ 画像データなどと組み合わせてマルチモーダルな学習が可能に
○ 特徴量の生成が不要
○ ストリーミングデータを学習可能
○ 表現学習が可能となり、生成モデリングや半教師あり学習などが可能に
モデル概要
● Feature transformer
○ 情報のフィルタリング
○ 各stepごとにフィルターは異な
る
● Attentive transformer
○ 前stepの学習情報を元に使う
特徴を決める
● Mask
○ 元データをAttentive transformer
の情報を元にマスクする
● stepごとに特徴選択が行われる
● 1stepでやること → 前stepの学習に基づき次に使う特徴を決定
学習の流れ 重要な部分
2種類の解釈性
大域的解釈性
Feature importance
局所的解釈性(各ステップのMask)
local interpretability
各ステップまたは最終的に各レコードでどの特徴が重視されたか(白→注目した特徴)
X
事前学習(オプション)
● Encoder-Decoder モデル
● マスクされた特徴をその他の
特徴から補完できるように学習
● 特徴間の関係を学習することが本番学習
での最適な特徴選択につながる
このEncoderモデルを用いた転移学習
実験
● Forest Cover Type dataset
○ 分類タスク
○ 地理情報から森林中の木の種類を予測
● Store sales dataset
○ 回帰タスク
○ 店舗ごとの日時売上データから
各店舗の将来の売上を予測
XGBoost , LightGBM, MLP などと同等以上の性能を示す
実験
● 事前学習の効果
○ 事前学習なしに比べ精度向上
○ 学習の収束が早くなる
● Higgs Boson dataset
○ 分類: ある特定の素粒子衝突イベントを検出
Kaggleでの実例
● Mechanisms of Action (MoA) Prediction (2020/09~2020/12)
○ 上位10チームのほとんどがTabNetを利用
○ アンサンブルの一手法として使っていた
○ 高い精度を保ちつつ他のモデルと異なった特徴
を重要視できているのでは
1st place solution 2nd place solution
TabNetのまとめ
● DNNと決定木のいいとこ取りしたアルゴリズム
● 特徴量生成を必要としないend-to-endな学習
● 各決定ステップにおいてsequential attention(逐次注意)を用いて重要な特徴を
学習する
● 特徴の重要性とその組み合わせ方を視覚化する局所的解釈性と、学習済みモ
デルに対する各特徴の貢献度を定量化する大域的解釈性の2種類の解釈性を
可能にした
● マスクされた特徴を予測するための教師なし事前学習を使用することで、性
能が向上する
チャンネル紹介
● チャンネル名: 【経営xデータサイエンスx開発】西岡 賢一郎のチャンネル
● URL: https://www.youtube.com/channel/UCpiskjqLv1AJg64jFCQIyBg
● チャンネルの内容
○ 経営・データサイエンス・開発に関する情報を発信しています。
○ 例: アジャイル開発、データパイプライン構築、AIで使われるアルゴリズム4種類など
● noteでも情報発信しています → https://note.com/kenichiro

TabNetの論文紹介