SlideShare a Scribd company logo
Adversarial Training
to avoid overfitting
NBME top#2 Solution and Discussion
https://www.kaggle.com/competitions/nbme-score-
clinical-patient-notes/discussion/323085
Feedback top#1で活用されたが,NBMEでは
“Although its CV score was quite higher than the one I selected above, both its public LB score and private LB score were lower.
It seems that my way of doing pseudo labeling was better. It may be that being quite new to these techniques I didn't tune
them correctly. I will try them in future competitions for sure.”
とあるので汎用性についてはさらなる実装と議論が必要か。
Adversarial Training
Inputs Perturbation
into “Local” worst-case
Weights Perturbation
into “Global” worst-case
Gradient-based Adversary Not gradient-based
Need Labels Not need Labels
FGM, SiFT
VAT, TRADES,
SMART
MART
AWP
Adversarial Training
https://arxiv.org/abs/1412.6572 : Goodfellow IJ et al., 2015, ICLR 2015
 摂動を加えた入力の中でモデルにとってhigh confidenceに間違えるようなもの = Adversarial Examples
 Adversarial Examplesを作成しながらモデル精度を高める = Adversarial Training
 ランダムノイズを加える点では一般的なaugmentationと捉えられるが,その中でもよりadversarialなもの
Adversarial Example
自信をもって間違えている
Perturbation is:
𝜂 = 𝜖 𝑠𝑖𝑔𝑛 𝛻𝑥𝐽 𝜃, 𝑥, 𝑦
loss backwardの時点で入力xについても自動微分が得られる
lossを大きくする方向 (𝑠𝑖𝑔𝑛)に微小(𝜖)動かす
参考) https://ai-scholar.tech/articles/adversarial-perturbation/Earlystopping
目的関数に追加: 𝐽 𝜃, 𝑥, 𝑦 = 𝛼𝐽 𝜃, 𝑥, 𝑦 + 1 − 𝛼 𝐽 𝜃, 𝑥 + 𝜖 𝑠𝑖𝑔𝑛 𝛻𝑥𝐽 𝜃, 𝑥, 𝑦 , 𝑦
VAT; Virtual Adversarial Training
https://arxiv.org/pdf/1507.00677.pdf : Miyato T et al., 2016, ICLR 2016
https://arxiv.org/pdf/1704.03976v2.pdf : Miyato T et al., 2018
https://arxiv.org/pdf/1605.07725.pdf : Miyato T et al., 2021
 モデル分散の平滑化を目的とした正則化として働く
 当初のAdversarial Examples作成にはLabel (𝑦)情報が必要 (𝜖 𝑠𝑖𝑔𝑛 𝛻𝑥𝐽 𝜃, 𝑥, 𝒚 なので)だが,VATではPerturbationを与えた
ときにmodel outputがどれくらい動くか(LDS)を用いるのでLabel不要
 勾配情報を用いるとsignによってadversarial directionが決まるが,VATではLDSを探索することでdirectionを決めるた
め,’virtual‘ adversarial trainingという名前がついた
𝑥 𝑛
𝑖𝑛𝑝𝑢𝑡 𝑠𝑝𝑎𝑐𝑒
𝑝 𝑦|𝑥 𝑛
, 𝜃
model output
(e.g. classification
prediction)
なめらかな予測曲線
= 過学習しづらい (正しい学習)
過学習の傾向
観測されたデータ
0.9
ちょっとノイズが入ると予測が
大きく変動する
ちょっと入力を動かしたときの予測自体 (精度ではない)が変動
するかをKL Div.で定量したものをLocal Distributional
Smoothing (LDS)とした。
∆𝐾𝐿 𝑟, 𝑥 𝑛 , 𝜃 ≡ 𝐾𝐿 𝑝 𝑦|𝑥 𝑛 , 𝜃 ||𝑝 𝑦|𝑥 𝑛 + 𝑟, 𝜃
𝑟𝑣−𝑎𝑑𝑣
𝑛
≡ 𝑎𝑟𝑔 max
𝑟
∆𝐾𝐿 𝑟, 𝑥 𝑛
, 𝜃 ; 𝑟 2 ≤ 𝜖
𝐿𝐷𝑆 𝑥 𝑛
, 𝜃 ≡ −∆𝐾𝐿 𝑟𝑣−𝑎𝑑𝑣
𝑛
, 𝑥 𝑛
, 𝜃
https://github.com/tensorflow/models/tree/master/research/adversarial_text
LDSを正則化項として目的関数に追加
詰まるところ, 𝑟𝑣−𝑎𝑑𝑣
𝑛
を決めるのが大変
𝑟𝑣−𝑎𝑑𝑣
𝑛
≡ 𝑎𝑟𝑔 max
𝑟
∆𝐾𝐿 𝑟, 𝑥 𝑛
, 𝜃 ; 𝑟 2 ≤ 𝜖
これ自身も学習で求める
𝑖𝑛𝑝𝑢𝑡 に対して𝑟𝑎𝑛𝑑𝑜𝑚 𝑣𝑒𝑐𝑡𝑜𝑟 𝑑を初期化して以下SGDによって更新
𝑑 ← 𝛻𝑟𝐾𝐿 𝑟, 𝑥, 𝜃
𝑟=𝜉𝑑
𝑤ℎ𝑒𝑟𝑒 𝑣 =
𝑣
𝑣 2
⋯
⋯
普通にloss求める
LDS求める
⋯
一旦勾配計算とめて
⋯
⋯
adversarial_losses.py
train_classifier.py
KL Div.に対するdの勾配を求めて
𝑟𝑣−𝑎𝑑𝑣を得る
https://github.com/tensorflow/models/tree/master/research/adversarial_text
VAT-Pytorchのイメージ
mt_dnn/perturbation.py
TRADES
https://arxiv.org/abs/1901.08573 : Zhang H et al., 2019, ICML 2019
𝜌𝑇𝑅𝐴𝐷𝐸𝑆 𝑤 =
1
𝑛
𝑖=1
𝑛
𝐶𝐸 𝑓𝑤 𝑥𝑖 , 𝑦𝑖 + 𝛽 𝑚𝑎𝑥𝐾𝐿 𝑓𝑤 𝑥𝑖 ||𝑓𝑤 𝑥𝑖
′
 正直VATとの違いがあまり分からない
 コードがとても使いやすい
https://github.com/yaodongyu/TRADES
from trades import trades_loss
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
# calculate robust loss - TRADES loss
loss = trades_loss(model=model,
x_natural=data,
y=target,
optimizer=optimizer,
step_size=args.step_size,
epsilon=args.epsilon,
perturb_steps=args.num_steps,
beta=args.beta,
distance='l_inf')
loss.backward()
optimizer.step()
MART
https://openreview.net/forum?id=rklOg6EFwS : Wang Y et al., 2019, ICLR 2020
https://github.com/YisenWang/MART
𝜌𝑀𝐴𝑅𝑇 𝑤 =
1
𝑛
𝑖=1
𝑛
𝐵𝐶𝐸 𝑓𝑤 𝑥𝑖
′
, 𝑦𝑖 + 𝜆 𝐾𝐿 𝑓𝑤 𝑥𝑖 ||𝑓𝑤 𝑥𝑖
′
∙ 1 − 𝑓𝑤 𝑥𝑖 𝑦𝑖
𝑤ℎ𝑒𝑟𝑒 𝑓𝑤 𝑥𝑖 𝑦𝑖
𝑑𝑒𝑛𝑜𝑡𝑒𝑠 𝑡ℎ𝑒 𝑦𝑖𝑡ℎ 𝑒𝑙𝑒𝑚𝑒𝑛𝑡 𝑜𝑓 𝑜𝑢𝑡𝑝𝑢𝑡 𝑣𝑒𝑐𝑡𝑜𝑟 𝑓𝑤 𝑥𝑖
𝑎𝑛𝑑 𝑥𝑖
′
𝑖𝑠 𝑓𝑟𝑜𝑚 𝑎𝑟𝑔 max
𝑥𝑖
′∈ℬ𝜖 𝑥𝑖
𝐶𝐸 𝑓𝑤 𝑥𝑖
′
, 𝑦𝑖
 正解例を当てることに焦点をあて,Adversarial Loss は負例(部)に関して足される
 adversarial examplesの生成には教師データが必要
from mart import mart_loss
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
# calculate robust loss - MART loss
loss = mart_loss(model=model,
x_natural=data,
y=target,
optimizer=optimizer,
step_size=args.step_size,
epsilon=args.epsilon,
perturb_steps=args.num_steps,
beta=args.beta,
distance='l_inf')
loss.backward()
optimizer.step()
SMART
https://arxiv.org/abs/1911.03437 : Jiang H et al., 2021
1.正則化項の追加と2.Optimizationの工夫によって構成される
1. Smoothness-Inducing Adversarial Regularization: VATと同じ
2. Bregman Proximal Point Optimization: 学習パラメータ𝜃を更新前から大きく離れないよう更新
𝜃𝑡+1 = 𝑎𝑟𝑔 min
𝜃
ℒ 𝜃 + 𝜆𝑠𝑅𝑠 𝜃 + 𝜇𝐷𝐵𝑟𝑒𝑔 𝜃, 𝜃𝑡
≒VAT
𝜆𝑆, 𝜇: ℎ𝑦𝑝𝑒𝑟𝑝𝑎𝑟𝑎𝑚𝑒𝑡𝑒𝑟
𝑤ℎ𝑒𝑟𝑒 𝐷𝐵𝑟𝑒𝑔 𝜃, 𝜃𝑡 =
1
𝑛 𝑖=1
𝑛
∆𝐾𝐿 𝑓 𝑥𝑖; 𝜃 , 𝑓 𝑥𝑖; 𝜃𝑡 𝑓 𝑥; 𝜃 は入力xに対するoutput
https://github.com/namisan/mt-dnn
たぶんBregman Proximal Point Optimizationについてはgithubコードに実装されていない
AWP; Adversarial Weight Perturbation
https://arxiv.org/abs/2004.05884 : Wu D et al., 2020
 double-perturbation mechanism: both inputs and weights are adversarially perturbed
 weightの重みに摂動を加えた場合のモデル精度の不安定性(weight loss landscape)の低さが重要であると主張
⇒ 一般化に成功
𝑤𝑒𝑖𝑔ℎ𝑡 𝑙𝑜𝑠𝑠 𝑙𝑎𝑛𝑑𝑠𝑐𝑎𝑝𝑒
𝑔 𝛼 = 𝜌 𝑤 + 𝛼𝑑 =
1
𝑛
𝑖=1
𝑛
max
𝑥′𝑖−𝑥𝑖 𝑝≤𝜖
ℓ 𝑓𝑤+𝛼𝑑 𝑥𝑖
′
, 𝑦𝑖
𝑤ℎ𝑒𝑟𝑒 𝑑 𝑖𝑠 𝑠𝑎𝑚𝑝𝑙𝑒𝑑 𝑓𝑟𝑜𝑚 𝑎 𝐺𝑎𝑢𝑠𝑠𝑖𝑎𝑛 𝑑𝑖𝑠𝑡𝑟𝑖𝑏𝑢𝑡𝑖𝑜𝑛 𝑎𝑛𝑑 𝑓𝑖𝑙𝑡𝑒𝑟 𝑛𝑜𝑟𝑚𝑎𝑙𝑖𝑧𝑒𝑑 𝑏𝑦 𝑑𝑙,𝑗 ←
𝑑𝑙,𝑗
𝑑𝑙,𝑗 𝐹
𝑤𝑙,𝑗 𝐹
重みの摂動に対して安定
過学習の状態では重みの摂動に
対して不安定
https://github.com/csdongxian/AWP
gap
が小さいほど
Test Accuracy
は高い傾向
⇒ gapをLossに追加
 なぜweight perturbationが有効かの考察
• adversarial perturbation on inputsはそれぞれの入力についてモデルが不得意とするperturbationを与える
= “local” worst-case
• adversarial perturbation on weightsは全データに関して予測を(程よく)崩すようなperturbationを与える
= “global” worst-case
⇒ ともに助け合いながらRobust modelが学習される
min
𝑤
𝜌 𝑤 + 𝜌 𝑤 + 𝑣 − 𝜌 𝑤 → min
𝑤
𝜌 𝑤 + 𝑣 ただし𝜌 𝑤 は入力データに対するadversarial loss
より
min
𝑤
max
𝑣∈𝑉
1
𝑛
𝑖=1
𝑛
max
𝑥𝑖
′−𝑥𝑖 𝑝
≤𝜖
ℓ 𝑓𝑤+𝑣 𝑥𝑖
′
, 𝑦𝑖
このmaximizeは各batchについて計算されるので注意
batch-sizeは重要。
AWPは結果として大きさに関する
正則化としても機能している
AWP Code
https://github.com/namisan/mt-dnn では,at_AWPやtrades_AWPコードが公開されているので任意のモデルに応用できるはず
for batch_idx, (data, target) in enumerate(train_loader):
x_natural, target = data.to(device), target.to(device)
# craft adversarial examples
x_adv = perturb_input(model=model,
x_natural=x_natural,
step_size=step_size,
epsilon=epsilon,
perturb_steps=args.num_steps,
distance=args.norm)
model.train()
# calculate adversarial weight perturbation
if epoch >= args.awp_warmup:
awp = awp_adversary.calc_awp(inputs_adv=x_adv,
inputs_clean=x_natural,
targets=target,
beta=args.beta)
awp_adversary.perturb(awp)
optimizer.zero_grad()
logits_adv = model(x_adv)
loss_robust = F.kl_div(F.log_softmax(logits_adv, dim=1),
F.softmax(model(x_natural), dim=1),
reduction='batchmean')
# calculate natural loss and backprop
logits = model(x_natural)
loss_natural = F.cross_entropy(logits, target)
loss = loss_natural + args.beta * loss_robust
inputsに対するadversarial attack
weightsに対するadversarial attack
AWP Code
NBME top#1 Code
https://www.kaggle.com/code/wht1996/feedback-nn-train/notebook
 正直参考にした論文と結構異なるので混乱…
 inputに対するadversarial trainingはなし (たぶんpre-trainedだからだと思う…)
def attack_backward(self, x, y, attention_mask,epoch):
if (self.adv_lr == 0) or (epoch < self.start_epoch):
return None
self._save()
for i in range(self.adv_step):
self._attack_step()
with torch.cuda.amp.autocast():
adv_loss, tr_logits = self.model(input_ids=x, attention_mask=attention_mask, labels=y)
adv_loss = adv_loss.mean()
self.optimizer.zero_grad()
self.scaler.scale(adv_loss).backward()
self._restore()
def _attack_step(self):
e = 1e-6
for name, param in self.model.named_parameters():
if param.requires_grad and param.grad is not None and self.adv_param in name:
norm1 = torch.norm(param.grad)
norm2 = torch.norm(param.data.detach())
if norm1 != 0 and not torch.isnan(norm1):
r_at = self.adv_lr * param.grad / (norm1 + e) * (norm2 + e)
param.data.add_(r_at)
param.data = torch.min(
torch.max(param.data, self.backup_eps[name][0]), self.backup_eps[name][1]
)
# param.data.clamp_(*self.backup_eps[name])
# Define AWP class in advance
awp = AWP(model,
optimizer,
adv_lr=args.adv_lr,
adv_eps=args.adv_eps,
start_epoch=args.num_train_st
eps/args.epochs,
scaler=scaler)
# during train....
# logits = model(inputs)
# loss = ....
# loss.backward()
awp.attack_backward(input_ids, labels,
attention_mask, step)
# optimizer.step()
𝜌 𝑤 + 𝑣 の𝑣が
𝑣 = 𝛻𝑤ℒ という感じ??
FGM; Fast Gradient Method
https://www.kaggle.com/c/tweet-sentiment-extraction/discussion/143764
 一番最初のAdversarial trainingのこと
 inputsに対するadversarial attackだが,NLPの場合embeddingに対してかかるのでweightsに対するadversarial attackの
ように記述する
 書き方からして先ほどのAWPはこれを真似たのだろう class FGM():
def __init__(self, model):
self.model = model
self.backup = {}
def attack(self, epsilon=1., emb_name='word_embeddings'):
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
self.backup[name] = param.data.clone()
norm = torch.norm(param.grad)
if norm != 0:
r_at = epsilon * param.grad / norm
param.data.add_(r_at)
def restore(self, emb_name='word_embeddings'):
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
assert name in self.backup
param.data = self.backup[name]
self.backup = {}
fgm = FGM(model)
for batch_input, batch_label in data:
loss = model(batch_input, batch_label)
loss.backward()
# adversarial training
fgm.attack()
loss_adv = model(batch_input, batch_label)
loss_adv.backward()
fgm.restore()
optimizer.step()
model.zero_grad()
SiFT; Scale Invariant Fine-Tuning
https://github.com/microsoft/DeBERTa/tree/master/DeBERTa/sift
 FGMと同じ。embeddingについてGradient-base adversarial attackを行う

More Related Content

What's hot

[DL輪読会]相互情報量最大化による表現学習
[DL輪読会]相互情報量最大化による表現学習[DL輪読会]相互情報量最大化による表現学習
[DL輪読会]相互情報量最大化による表現学習
Deep Learning JP
 
[DL輪読会]GLIDE: Guided Language to Image Diffusion for Generation and Editing
[DL輪読会]GLIDE: Guided Language to Image Diffusion  for Generation and Editing[DL輪読会]GLIDE: Guided Language to Image Diffusion  for Generation and Editing
[DL輪読会]GLIDE: Guided Language to Image Diffusion for Generation and Editing
Deep Learning JP
 
深層学習の数理:カーネル法, スパース推定との接点
深層学習の数理:カーネル法, スパース推定との接点深層学習の数理:カーネル法, スパース推定との接点
深層学習の数理:カーネル法, スパース推定との接点
Taiji Suzuki
 
[DL輪読会]Revisiting Deep Learning Models for Tabular Data (NeurIPS 2021) 表形式デー...
[DL輪読会]Revisiting Deep Learning Models for Tabular Data  (NeurIPS 2021) 表形式デー...[DL輪読会]Revisiting Deep Learning Models for Tabular Data  (NeurIPS 2021) 表形式デー...
[DL輪読会]Revisiting Deep Learning Models for Tabular Data (NeurIPS 2021) 表形式デー...
Deep Learning JP
 
[DL輪読会]data2vec: A General Framework for Self-supervised Learning in Speech,...
[DL輪読会]data2vec: A General Framework for  Self-supervised Learning in Speech,...[DL輪読会]data2vec: A General Framework for  Self-supervised Learning in Speech,...
[DL輪読会]data2vec: A General Framework for Self-supervised Learning in Speech,...
Deep Learning JP
 
【DL輪読会】How Much Can CLIP Benefit Vision-and-Language Tasks?
【DL輪読会】How Much Can CLIP Benefit Vision-and-Language Tasks? 【DL輪読会】How Much Can CLIP Benefit Vision-and-Language Tasks?
【DL輪読会】How Much Can CLIP Benefit Vision-and-Language Tasks?
Deep Learning JP
 
Triplet Loss 徹底解説
Triplet Loss 徹底解説Triplet Loss 徹底解説
Triplet Loss 徹底解説
tancoro
 
Swin Transformer (ICCV'21 Best Paper) を完璧に理解する資料
Swin Transformer (ICCV'21 Best Paper) を完璧に理解する資料Swin Transformer (ICCV'21 Best Paper) を完璧に理解する資料
Swin Transformer (ICCV'21 Best Paper) を完璧に理解する資料
Yusuke Uchida
 
【メタサーベイ】数式ドリブン教師あり学習
【メタサーベイ】数式ドリブン教師あり学習【メタサーベイ】数式ドリブン教師あり学習
【メタサーベイ】数式ドリブン教師あり学習
cvpaper. challenge
 
劣モジュラ最適化と機械学習1章
劣モジュラ最適化と機械学習1章劣モジュラ最適化と機械学習1章
劣モジュラ最適化と機械学習1章
Hakky St
 
GAN(と強化学習との関係)
GAN(と強化学習との関係)GAN(と強化学習との関係)
GAN(と強化学習との関係)
Masahiro Suzuki
 
最適輸送の計算アルゴリズムの研究動向
最適輸送の計算アルゴリズムの研究動向最適輸送の計算アルゴリズムの研究動向
最適輸送の計算アルゴリズムの研究動向
ohken
 
Transformer メタサーベイ
Transformer メタサーベイTransformer メタサーベイ
Transformer メタサーベイ
cvpaper. challenge
 
Active Learning の基礎と最近の研究
Active Learning の基礎と最近の研究Active Learning の基礎と最近の研究
Active Learning の基礎と最近の研究
Fumihiko Takahashi
 
PRML学習者から入る深層生成モデル入門
PRML学習者から入る深層生成モデル入門PRML学習者から入る深層生成モデル入門
PRML学習者から入る深層生成モデル入門
tmtm otm
 
Optimizer入門&最新動向
Optimizer入門&最新動向Optimizer入門&最新動向
Optimizer入門&最新動向
Motokawa Tetsuya
 
深層学習の数理
深層学習の数理深層学習の数理
深層学習の数理
Taiji Suzuki
 
AHC-Lab M1勉強会 論文の読み方・書き方
AHC-Lab M1勉強会 論文の読み方・書き方AHC-Lab M1勉強会 論文の読み方・書き方
AHC-Lab M1勉強会 論文の読み方・書き方
Shinagawa Seitaro
 
[DL輪読会]ICLR2020の分布外検知速報
[DL輪読会]ICLR2020の分布外検知速報[DL輪読会]ICLR2020の分布外検知速報
[DL輪読会]ICLR2020の分布外検知速報
Deep Learning JP
 
[Ridge-i 論文よみかい] Wasserstein auto encoder
[Ridge-i 論文よみかい] Wasserstein auto encoder[Ridge-i 論文よみかい] Wasserstein auto encoder
[Ridge-i 論文よみかい] Wasserstein auto encoder
Masanari Kimura
 

What's hot (20)

[DL輪読会]相互情報量最大化による表現学習
[DL輪読会]相互情報量最大化による表現学習[DL輪読会]相互情報量最大化による表現学習
[DL輪読会]相互情報量最大化による表現学習
 
[DL輪読会]GLIDE: Guided Language to Image Diffusion for Generation and Editing
[DL輪読会]GLIDE: Guided Language to Image Diffusion  for Generation and Editing[DL輪読会]GLIDE: Guided Language to Image Diffusion  for Generation and Editing
[DL輪読会]GLIDE: Guided Language to Image Diffusion for Generation and Editing
 
深層学習の数理:カーネル法, スパース推定との接点
深層学習の数理:カーネル法, スパース推定との接点深層学習の数理:カーネル法, スパース推定との接点
深層学習の数理:カーネル法, スパース推定との接点
 
[DL輪読会]Revisiting Deep Learning Models for Tabular Data (NeurIPS 2021) 表形式デー...
[DL輪読会]Revisiting Deep Learning Models for Tabular Data  (NeurIPS 2021) 表形式デー...[DL輪読会]Revisiting Deep Learning Models for Tabular Data  (NeurIPS 2021) 表形式デー...
[DL輪読会]Revisiting Deep Learning Models for Tabular Data (NeurIPS 2021) 表形式デー...
 
[DL輪読会]data2vec: A General Framework for Self-supervised Learning in Speech,...
[DL輪読会]data2vec: A General Framework for  Self-supervised Learning in Speech,...[DL輪読会]data2vec: A General Framework for  Self-supervised Learning in Speech,...
[DL輪読会]data2vec: A General Framework for Self-supervised Learning in Speech,...
 
【DL輪読会】How Much Can CLIP Benefit Vision-and-Language Tasks?
【DL輪読会】How Much Can CLIP Benefit Vision-and-Language Tasks? 【DL輪読会】How Much Can CLIP Benefit Vision-and-Language Tasks?
【DL輪読会】How Much Can CLIP Benefit Vision-and-Language Tasks?
 
Triplet Loss 徹底解説
Triplet Loss 徹底解説Triplet Loss 徹底解説
Triplet Loss 徹底解説
 
Swin Transformer (ICCV'21 Best Paper) を完璧に理解する資料
Swin Transformer (ICCV'21 Best Paper) を完璧に理解する資料Swin Transformer (ICCV'21 Best Paper) を完璧に理解する資料
Swin Transformer (ICCV'21 Best Paper) を完璧に理解する資料
 
【メタサーベイ】数式ドリブン教師あり学習
【メタサーベイ】数式ドリブン教師あり学習【メタサーベイ】数式ドリブン教師あり学習
【メタサーベイ】数式ドリブン教師あり学習
 
劣モジュラ最適化と機械学習1章
劣モジュラ最適化と機械学習1章劣モジュラ最適化と機械学習1章
劣モジュラ最適化と機械学習1章
 
GAN(と強化学習との関係)
GAN(と強化学習との関係)GAN(と強化学習との関係)
GAN(と強化学習との関係)
 
最適輸送の計算アルゴリズムの研究動向
最適輸送の計算アルゴリズムの研究動向最適輸送の計算アルゴリズムの研究動向
最適輸送の計算アルゴリズムの研究動向
 
Transformer メタサーベイ
Transformer メタサーベイTransformer メタサーベイ
Transformer メタサーベイ
 
Active Learning の基礎と最近の研究
Active Learning の基礎と最近の研究Active Learning の基礎と最近の研究
Active Learning の基礎と最近の研究
 
PRML学習者から入る深層生成モデル入門
PRML学習者から入る深層生成モデル入門PRML学習者から入る深層生成モデル入門
PRML学習者から入る深層生成モデル入門
 
Optimizer入門&最新動向
Optimizer入門&最新動向Optimizer入門&最新動向
Optimizer入門&最新動向
 
深層学習の数理
深層学習の数理深層学習の数理
深層学習の数理
 
AHC-Lab M1勉強会 論文の読み方・書き方
AHC-Lab M1勉強会 論文の読み方・書き方AHC-Lab M1勉強会 論文の読み方・書き方
AHC-Lab M1勉強会 論文の読み方・書き方
 
[DL輪読会]ICLR2020の分布外検知速報
[DL輪読会]ICLR2020の分布外検知速報[DL輪読会]ICLR2020の分布外検知速報
[DL輪読会]ICLR2020の分布外検知速報
 
[Ridge-i 論文よみかい] Wasserstein auto encoder
[Ridge-i 論文よみかい] Wasserstein auto encoder[Ridge-i 論文よみかい] Wasserstein auto encoder
[Ridge-i 論文よみかい] Wasserstein auto encoder
 

Similar to adversarial training.pptx

ADVENTURE_Solidの概要
ADVENTURE_Solidの概要ADVENTURE_Solidの概要
ADVENTURE_Solidの概要
ADVENTURE Project
 
Azure Machine Learning Services 概要 - 2019年3月版
Azure Machine Learning Services 概要 - 2019年3月版Azure Machine Learning Services 概要 - 2019年3月版
Azure Machine Learning Services 概要 - 2019年3月版
Daiyu Hatakeyama
 
20181212 - PGconf.ASIA - LT
20181212 - PGconf.ASIA - LT20181212 - PGconf.ASIA - LT
20181212 - PGconf.ASIA - LT
Kohei KaiGai
 
Wandb Monthly Meetup August 2023.pdf
Wandb Monthly Meetup August 2023.pdfWandb Monthly Meetup August 2023.pdf
Wandb Monthly Meetup August 2023.pdf
Yuya Yamamoto
 
第1回 Jubatusハンズオン
第1回 Jubatusハンズオン第1回 Jubatusハンズオン
第1回 JubatusハンズオンJubatusOfficial
 
第1回 Jubatusハンズオン
第1回 Jubatusハンズオン第1回 Jubatusハンズオン
第1回 JubatusハンズオンYuya Unno
 
20170127 JAWS HPC-UG#8
20170127 JAWS HPC-UG#820170127 JAWS HPC-UG#8
20170127 JAWS HPC-UG#8
Kohei KaiGai
 
Okinawa.rb 第2回勉強会
Okinawa.rb 第2回勉強会Okinawa.rb 第2回勉強会
Okinawa.rb 第2回勉強会
Naoki Takaesu
 
Learning Template Library Design using Boost.Geomtry
Learning Template Library Design using Boost.GeomtryLearning Template Library Design using Boost.Geomtry
Learning Template Library Design using Boost.GeomtryAkira Takahashi
 
Asakusa Enterprise Batch Processing Framework for Hadoop
Asakusa Enterprise Batch Processing Framework for HadoopAsakusa Enterprise Batch Processing Framework for Hadoop
Asakusa Enterprise Batch Processing Framework for Hadoop
Takashi Kambayashi
 
Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説
Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説
Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説
Daiyu Hatakeyama
 
新しい並列for構文のご提案
新しい並列for構文のご提案新しい並列for構文のご提案
新しい並列for構文のご提案
yohhoy
 
Java ORマッパー選定のポイント #jsug
Java ORマッパー選定のポイント #jsugJava ORマッパー選定のポイント #jsug
Java ORマッパー選定のポイント #jsug
Masatoshi Tada
 
Try_to_writecode_practicaltest #atest_hack
Try_to_writecode_practicaltest #atest_hackTry_to_writecode_practicaltest #atest_hack
Try_to_writecode_practicaltest #atest_hack
kimukou_26 Kimukou
 
プログラミングで言いたい聞きたいこと集
プログラミングで言いたい聞きたいこと集プログラミングで言いたい聞きたいこと集
プログラミングで言いたい聞きたいこと集tecopark
 
プログラミングで言いたいこと聞きたいこと集
プログラミングで言いたいこと聞きたいこと集プログラミングで言いたいこと聞きたいこと集
プログラミングで言いたいこと聞きたいこと集tecopark
 
CMSI計算科学技術特論A (2015) 第3回 OpenMPの基礎
CMSI計算科学技術特論A (2015) 第3回 OpenMPの基礎CMSI計算科学技術特論A (2015) 第3回 OpenMPの基礎
CMSI計算科学技術特論A (2015) 第3回 OpenMPの基礎
Computational Materials Science Initiative
 
エンジニアのための機械学習の基礎
エンジニアのための機械学習の基礎エンジニアのための機械学習の基礎
エンジニアのための機械学習の基礎
Daiyu Hatakeyama
 
[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用
[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用
[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用
de:code 2017
 
ウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdf
ウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdfウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdf
ウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdf
Yuya Yamamoto
 

Similar to adversarial training.pptx (20)

ADVENTURE_Solidの概要
ADVENTURE_Solidの概要ADVENTURE_Solidの概要
ADVENTURE_Solidの概要
 
Azure Machine Learning Services 概要 - 2019年3月版
Azure Machine Learning Services 概要 - 2019年3月版Azure Machine Learning Services 概要 - 2019年3月版
Azure Machine Learning Services 概要 - 2019年3月版
 
20181212 - PGconf.ASIA - LT
20181212 - PGconf.ASIA - LT20181212 - PGconf.ASIA - LT
20181212 - PGconf.ASIA - LT
 
Wandb Monthly Meetup August 2023.pdf
Wandb Monthly Meetup August 2023.pdfWandb Monthly Meetup August 2023.pdf
Wandb Monthly Meetup August 2023.pdf
 
第1回 Jubatusハンズオン
第1回 Jubatusハンズオン第1回 Jubatusハンズオン
第1回 Jubatusハンズオン
 
第1回 Jubatusハンズオン
第1回 Jubatusハンズオン第1回 Jubatusハンズオン
第1回 Jubatusハンズオン
 
20170127 JAWS HPC-UG#8
20170127 JAWS HPC-UG#820170127 JAWS HPC-UG#8
20170127 JAWS HPC-UG#8
 
Okinawa.rb 第2回勉強会
Okinawa.rb 第2回勉強会Okinawa.rb 第2回勉強会
Okinawa.rb 第2回勉強会
 
Learning Template Library Design using Boost.Geomtry
Learning Template Library Design using Boost.GeomtryLearning Template Library Design using Boost.Geomtry
Learning Template Library Design using Boost.Geomtry
 
Asakusa Enterprise Batch Processing Framework for Hadoop
Asakusa Enterprise Batch Processing Framework for HadoopAsakusa Enterprise Batch Processing Framework for Hadoop
Asakusa Enterprise Batch Processing Framework for Hadoop
 
Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説
Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説
Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説
 
新しい並列for構文のご提案
新しい並列for構文のご提案新しい並列for構文のご提案
新しい並列for構文のご提案
 
Java ORマッパー選定のポイント #jsug
Java ORマッパー選定のポイント #jsugJava ORマッパー選定のポイント #jsug
Java ORマッパー選定のポイント #jsug
 
Try_to_writecode_practicaltest #atest_hack
Try_to_writecode_practicaltest #atest_hackTry_to_writecode_practicaltest #atest_hack
Try_to_writecode_practicaltest #atest_hack
 
プログラミングで言いたい聞きたいこと集
プログラミングで言いたい聞きたいこと集プログラミングで言いたい聞きたいこと集
プログラミングで言いたい聞きたいこと集
 
プログラミングで言いたいこと聞きたいこと集
プログラミングで言いたいこと聞きたいこと集プログラミングで言いたいこと聞きたいこと集
プログラミングで言いたいこと聞きたいこと集
 
CMSI計算科学技術特論A (2015) 第3回 OpenMPの基礎
CMSI計算科学技術特論A (2015) 第3回 OpenMPの基礎CMSI計算科学技術特論A (2015) 第3回 OpenMPの基礎
CMSI計算科学技術特論A (2015) 第3回 OpenMPの基礎
 
エンジニアのための機械学習の基礎
エンジニアのための機械学習の基礎エンジニアのための機械学習の基礎
エンジニアのための機械学習の基礎
 
[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用
[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用
[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用
 
ウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdf
ウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdfウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdf
ウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdf
 

adversarial training.pptx

  • 1. Adversarial Training to avoid overfitting NBME top#2 Solution and Discussion https://www.kaggle.com/competitions/nbme-score- clinical-patient-notes/discussion/323085 Feedback top#1で活用されたが,NBMEでは “Although its CV score was quite higher than the one I selected above, both its public LB score and private LB score were lower. It seems that my way of doing pseudo labeling was better. It may be that being quite new to these techniques I didn't tune them correctly. I will try them in future competitions for sure.” とあるので汎用性についてはさらなる実装と議論が必要か。
  • 2. Adversarial Training Inputs Perturbation into “Local” worst-case Weights Perturbation into “Global” worst-case Gradient-based Adversary Not gradient-based Need Labels Not need Labels FGM, SiFT VAT, TRADES, SMART MART AWP
  • 3. Adversarial Training https://arxiv.org/abs/1412.6572 : Goodfellow IJ et al., 2015, ICLR 2015  摂動を加えた入力の中でモデルにとってhigh confidenceに間違えるようなもの = Adversarial Examples  Adversarial Examplesを作成しながらモデル精度を高める = Adversarial Training  ランダムノイズを加える点では一般的なaugmentationと捉えられるが,その中でもよりadversarialなもの Adversarial Example 自信をもって間違えている Perturbation is: 𝜂 = 𝜖 𝑠𝑖𝑔𝑛 𝛻𝑥𝐽 𝜃, 𝑥, 𝑦 loss backwardの時点で入力xについても自動微分が得られる lossを大きくする方向 (𝑠𝑖𝑔𝑛)に微小(𝜖)動かす 参考) https://ai-scholar.tech/articles/adversarial-perturbation/Earlystopping 目的関数に追加: 𝐽 𝜃, 𝑥, 𝑦 = 𝛼𝐽 𝜃, 𝑥, 𝑦 + 1 − 𝛼 𝐽 𝜃, 𝑥 + 𝜖 𝑠𝑖𝑔𝑛 𝛻𝑥𝐽 𝜃, 𝑥, 𝑦 , 𝑦
  • 4. VAT; Virtual Adversarial Training https://arxiv.org/pdf/1507.00677.pdf : Miyato T et al., 2016, ICLR 2016 https://arxiv.org/pdf/1704.03976v2.pdf : Miyato T et al., 2018 https://arxiv.org/pdf/1605.07725.pdf : Miyato T et al., 2021  モデル分散の平滑化を目的とした正則化として働く  当初のAdversarial Examples作成にはLabel (𝑦)情報が必要 (𝜖 𝑠𝑖𝑔𝑛 𝛻𝑥𝐽 𝜃, 𝑥, 𝒚 なので)だが,VATではPerturbationを与えた ときにmodel outputがどれくらい動くか(LDS)を用いるのでLabel不要  勾配情報を用いるとsignによってadversarial directionが決まるが,VATではLDSを探索することでdirectionを決めるた め,’virtual‘ adversarial trainingという名前がついた 𝑥 𝑛 𝑖𝑛𝑝𝑢𝑡 𝑠𝑝𝑎𝑐𝑒 𝑝 𝑦|𝑥 𝑛 , 𝜃 model output (e.g. classification prediction) なめらかな予測曲線 = 過学習しづらい (正しい学習) 過学習の傾向 観測されたデータ 0.9 ちょっとノイズが入ると予測が 大きく変動する ちょっと入力を動かしたときの予測自体 (精度ではない)が変動 するかをKL Div.で定量したものをLocal Distributional Smoothing (LDS)とした。 ∆𝐾𝐿 𝑟, 𝑥 𝑛 , 𝜃 ≡ 𝐾𝐿 𝑝 𝑦|𝑥 𝑛 , 𝜃 ||𝑝 𝑦|𝑥 𝑛 + 𝑟, 𝜃 𝑟𝑣−𝑎𝑑𝑣 𝑛 ≡ 𝑎𝑟𝑔 max 𝑟 ∆𝐾𝐿 𝑟, 𝑥 𝑛 , 𝜃 ; 𝑟 2 ≤ 𝜖 𝐿𝐷𝑆 𝑥 𝑛 , 𝜃 ≡ −∆𝐾𝐿 𝑟𝑣−𝑎𝑑𝑣 𝑛 , 𝑥 𝑛 , 𝜃 https://github.com/tensorflow/models/tree/master/research/adversarial_text LDSを正則化項として目的関数に追加
  • 5. 詰まるところ, 𝑟𝑣−𝑎𝑑𝑣 𝑛 を決めるのが大変 𝑟𝑣−𝑎𝑑𝑣 𝑛 ≡ 𝑎𝑟𝑔 max 𝑟 ∆𝐾𝐿 𝑟, 𝑥 𝑛 , 𝜃 ; 𝑟 2 ≤ 𝜖 これ自身も学習で求める 𝑖𝑛𝑝𝑢𝑡 に対して𝑟𝑎𝑛𝑑𝑜𝑚 𝑣𝑒𝑐𝑡𝑜𝑟 𝑑を初期化して以下SGDによって更新 𝑑 ← 𝛻𝑟𝐾𝐿 𝑟, 𝑥, 𝜃 𝑟=𝜉𝑑 𝑤ℎ𝑒𝑟𝑒 𝑣 = 𝑣 𝑣 2 ⋯ ⋯ 普通にloss求める LDS求める ⋯ 一旦勾配計算とめて ⋯ ⋯ adversarial_losses.py train_classifier.py KL Div.に対するdの勾配を求めて 𝑟𝑣−𝑎𝑑𝑣を得る https://github.com/tensorflow/models/tree/master/research/adversarial_text
  • 7. TRADES https://arxiv.org/abs/1901.08573 : Zhang H et al., 2019, ICML 2019 𝜌𝑇𝑅𝐴𝐷𝐸𝑆 𝑤 = 1 𝑛 𝑖=1 𝑛 𝐶𝐸 𝑓𝑤 𝑥𝑖 , 𝑦𝑖 + 𝛽 𝑚𝑎𝑥𝐾𝐿 𝑓𝑤 𝑥𝑖 ||𝑓𝑤 𝑥𝑖 ′  正直VATとの違いがあまり分からない  コードがとても使いやすい https://github.com/yaodongyu/TRADES from trades import trades_loss def train(args, model, device, train_loader, optimizer, epoch): model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() # calculate robust loss - TRADES loss loss = trades_loss(model=model, x_natural=data, y=target, optimizer=optimizer, step_size=args.step_size, epsilon=args.epsilon, perturb_steps=args.num_steps, beta=args.beta, distance='l_inf') loss.backward() optimizer.step()
  • 8. MART https://openreview.net/forum?id=rklOg6EFwS : Wang Y et al., 2019, ICLR 2020 https://github.com/YisenWang/MART 𝜌𝑀𝐴𝑅𝑇 𝑤 = 1 𝑛 𝑖=1 𝑛 𝐵𝐶𝐸 𝑓𝑤 𝑥𝑖 ′ , 𝑦𝑖 + 𝜆 𝐾𝐿 𝑓𝑤 𝑥𝑖 ||𝑓𝑤 𝑥𝑖 ′ ∙ 1 − 𝑓𝑤 𝑥𝑖 𝑦𝑖 𝑤ℎ𝑒𝑟𝑒 𝑓𝑤 𝑥𝑖 𝑦𝑖 𝑑𝑒𝑛𝑜𝑡𝑒𝑠 𝑡ℎ𝑒 𝑦𝑖𝑡ℎ 𝑒𝑙𝑒𝑚𝑒𝑛𝑡 𝑜𝑓 𝑜𝑢𝑡𝑝𝑢𝑡 𝑣𝑒𝑐𝑡𝑜𝑟 𝑓𝑤 𝑥𝑖 𝑎𝑛𝑑 𝑥𝑖 ′ 𝑖𝑠 𝑓𝑟𝑜𝑚 𝑎𝑟𝑔 max 𝑥𝑖 ′∈ℬ𝜖 𝑥𝑖 𝐶𝐸 𝑓𝑤 𝑥𝑖 ′ , 𝑦𝑖  正解例を当てることに焦点をあて,Adversarial Loss は負例(部)に関して足される  adversarial examplesの生成には教師データが必要 from mart import mart_loss def train(args, model, device, train_loader, optimizer, epoch): model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() # calculate robust loss - MART loss loss = mart_loss(model=model, x_natural=data, y=target, optimizer=optimizer, step_size=args.step_size, epsilon=args.epsilon, perturb_steps=args.num_steps, beta=args.beta, distance='l_inf') loss.backward() optimizer.step()
  • 9. SMART https://arxiv.org/abs/1911.03437 : Jiang H et al., 2021 1.正則化項の追加と2.Optimizationの工夫によって構成される 1. Smoothness-Inducing Adversarial Regularization: VATと同じ 2. Bregman Proximal Point Optimization: 学習パラメータ𝜃を更新前から大きく離れないよう更新 𝜃𝑡+1 = 𝑎𝑟𝑔 min 𝜃 ℒ 𝜃 + 𝜆𝑠𝑅𝑠 𝜃 + 𝜇𝐷𝐵𝑟𝑒𝑔 𝜃, 𝜃𝑡 ≒VAT 𝜆𝑆, 𝜇: ℎ𝑦𝑝𝑒𝑟𝑝𝑎𝑟𝑎𝑚𝑒𝑡𝑒𝑟 𝑤ℎ𝑒𝑟𝑒 𝐷𝐵𝑟𝑒𝑔 𝜃, 𝜃𝑡 = 1 𝑛 𝑖=1 𝑛 ∆𝐾𝐿 𝑓 𝑥𝑖; 𝜃 , 𝑓 𝑥𝑖; 𝜃𝑡 𝑓 𝑥; 𝜃 は入力xに対するoutput https://github.com/namisan/mt-dnn たぶんBregman Proximal Point Optimizationについてはgithubコードに実装されていない
  • 10. AWP; Adversarial Weight Perturbation https://arxiv.org/abs/2004.05884 : Wu D et al., 2020  double-perturbation mechanism: both inputs and weights are adversarially perturbed  weightの重みに摂動を加えた場合のモデル精度の不安定性(weight loss landscape)の低さが重要であると主張 ⇒ 一般化に成功 𝑤𝑒𝑖𝑔ℎ𝑡 𝑙𝑜𝑠𝑠 𝑙𝑎𝑛𝑑𝑠𝑐𝑎𝑝𝑒 𝑔 𝛼 = 𝜌 𝑤 + 𝛼𝑑 = 1 𝑛 𝑖=1 𝑛 max 𝑥′𝑖−𝑥𝑖 𝑝≤𝜖 ℓ 𝑓𝑤+𝛼𝑑 𝑥𝑖 ′ , 𝑦𝑖 𝑤ℎ𝑒𝑟𝑒 𝑑 𝑖𝑠 𝑠𝑎𝑚𝑝𝑙𝑒𝑑 𝑓𝑟𝑜𝑚 𝑎 𝐺𝑎𝑢𝑠𝑠𝑖𝑎𝑛 𝑑𝑖𝑠𝑡𝑟𝑖𝑏𝑢𝑡𝑖𝑜𝑛 𝑎𝑛𝑑 𝑓𝑖𝑙𝑡𝑒𝑟 𝑛𝑜𝑟𝑚𝑎𝑙𝑖𝑧𝑒𝑑 𝑏𝑦 𝑑𝑙,𝑗 ← 𝑑𝑙,𝑗 𝑑𝑙,𝑗 𝐹 𝑤𝑙,𝑗 𝐹 重みの摂動に対して安定 過学習の状態では重みの摂動に 対して不安定 https://github.com/csdongxian/AWP gap が小さいほど Test Accuracy は高い傾向 ⇒ gapをLossに追加
  • 11.  なぜweight perturbationが有効かの考察 • adversarial perturbation on inputsはそれぞれの入力についてモデルが不得意とするperturbationを与える = “local” worst-case • adversarial perturbation on weightsは全データに関して予測を(程よく)崩すようなperturbationを与える = “global” worst-case ⇒ ともに助け合いながらRobust modelが学習される min 𝑤 𝜌 𝑤 + 𝜌 𝑤 + 𝑣 − 𝜌 𝑤 → min 𝑤 𝜌 𝑤 + 𝑣 ただし𝜌 𝑤 は入力データに対するadversarial loss より min 𝑤 max 𝑣∈𝑉 1 𝑛 𝑖=1 𝑛 max 𝑥𝑖 ′−𝑥𝑖 𝑝 ≤𝜖 ℓ 𝑓𝑤+𝑣 𝑥𝑖 ′ , 𝑦𝑖 このmaximizeは各batchについて計算されるので注意 batch-sizeは重要。 AWPは結果として大きさに関する 正則化としても機能している
  • 12. AWP Code https://github.com/namisan/mt-dnn では,at_AWPやtrades_AWPコードが公開されているので任意のモデルに応用できるはず for batch_idx, (data, target) in enumerate(train_loader): x_natural, target = data.to(device), target.to(device) # craft adversarial examples x_adv = perturb_input(model=model, x_natural=x_natural, step_size=step_size, epsilon=epsilon, perturb_steps=args.num_steps, distance=args.norm) model.train() # calculate adversarial weight perturbation if epoch >= args.awp_warmup: awp = awp_adversary.calc_awp(inputs_adv=x_adv, inputs_clean=x_natural, targets=target, beta=args.beta) awp_adversary.perturb(awp) optimizer.zero_grad() logits_adv = model(x_adv) loss_robust = F.kl_div(F.log_softmax(logits_adv, dim=1), F.softmax(model(x_natural), dim=1), reduction='batchmean') # calculate natural loss and backprop logits = model(x_natural) loss_natural = F.cross_entropy(logits, target) loss = loss_natural + args.beta * loss_robust inputsに対するadversarial attack weightsに対するadversarial attack
  • 13. AWP Code NBME top#1 Code https://www.kaggle.com/code/wht1996/feedback-nn-train/notebook  正直参考にした論文と結構異なるので混乱…  inputに対するadversarial trainingはなし (たぶんpre-trainedだからだと思う…) def attack_backward(self, x, y, attention_mask,epoch): if (self.adv_lr == 0) or (epoch < self.start_epoch): return None self._save() for i in range(self.adv_step): self._attack_step() with torch.cuda.amp.autocast(): adv_loss, tr_logits = self.model(input_ids=x, attention_mask=attention_mask, labels=y) adv_loss = adv_loss.mean() self.optimizer.zero_grad() self.scaler.scale(adv_loss).backward() self._restore() def _attack_step(self): e = 1e-6 for name, param in self.model.named_parameters(): if param.requires_grad and param.grad is not None and self.adv_param in name: norm1 = torch.norm(param.grad) norm2 = torch.norm(param.data.detach()) if norm1 != 0 and not torch.isnan(norm1): r_at = self.adv_lr * param.grad / (norm1 + e) * (norm2 + e) param.data.add_(r_at) param.data = torch.min( torch.max(param.data, self.backup_eps[name][0]), self.backup_eps[name][1] ) # param.data.clamp_(*self.backup_eps[name]) # Define AWP class in advance awp = AWP(model, optimizer, adv_lr=args.adv_lr, adv_eps=args.adv_eps, start_epoch=args.num_train_st eps/args.epochs, scaler=scaler) # during train.... # logits = model(inputs) # loss = .... # loss.backward() awp.attack_backward(input_ids, labels, attention_mask, step) # optimizer.step() 𝜌 𝑤 + 𝑣 の𝑣が 𝑣 = 𝛻𝑤ℒ という感じ??
  • 14. FGM; Fast Gradient Method https://www.kaggle.com/c/tweet-sentiment-extraction/discussion/143764  一番最初のAdversarial trainingのこと  inputsに対するadversarial attackだが,NLPの場合embeddingに対してかかるのでweightsに対するadversarial attackの ように記述する  書き方からして先ほどのAWPはこれを真似たのだろう class FGM(): def __init__(self, model): self.model = model self.backup = {} def attack(self, epsilon=1., emb_name='word_embeddings'): for name, param in self.model.named_parameters(): if param.requires_grad and emb_name in name: self.backup[name] = param.data.clone() norm = torch.norm(param.grad) if norm != 0: r_at = epsilon * param.grad / norm param.data.add_(r_at) def restore(self, emb_name='word_embeddings'): for name, param in self.model.named_parameters(): if param.requires_grad and emb_name in name: assert name in self.backup param.data = self.backup[name] self.backup = {} fgm = FGM(model) for batch_input, batch_label in data: loss = model(batch_input, batch_label) loss.backward() # adversarial training fgm.attack() loss_adv = model(batch_input, batch_label) loss_adv.backward() fgm.restore() optimizer.step() model.zero_grad() SiFT; Scale Invariant Fine-Tuning https://github.com/microsoft/DeBERTa/tree/master/DeBERTa/sift  FGMと同じ。embeddingについてGradient-base adversarial attackを行う