
An Explanation of Chainer's Trainer, and NStepLSTM

Retrieva Seminar, 2017/03/15
Movie: https://www.youtube.com/watch?v=ok_bvPKAEaM


An Explanation of Chainer's Trainer, and NStepLSTM

  1. An Explanation of Chainer's Trainer, and NStepLSTM (Retrieva, Inc.)
  2. About me
     • Kei Shiratsuchi (シラツチ ケイ)
     • Retrieva, Inc., joined April 2016
     • Ruby on Rails / JavaScript (a front-end person)
     • Researched complex network science at university
     • Currently getting started with Chainer
  3. Agenda
     • Part 1: The Trainer in Chainer
     • Part 2: Wrestling with NStepLSTM
  4. Quick poll
     • Who is using Chainer?
     • Who is using Chainer's Trainer?
     • Who is using LSTM?
     • Who is using NStepLSTM?
     • Who is using NStepLSTM together with Trainer?
  5. Part 1: The Trainer in Chainer
  6. The Trainer in Chainer
     • A training framework introduced in Chainer 1.11.0.
     • Batch fetching and forward/backward are abstracted away.
     • Progress display, model snapshots, and so on.
     • People who came to Chainer after Trainer appeared (myself included) find the MNIST sample so abstracted by Trainer that it is hard to tell what is actually happening.
     • People who have used Chainer from before often still run without Trainer.
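     For reference through the rest of Part 1, this is roughly what the standard setup looks like, condensed from examples/mnist/train_mnist.py (batch size, epoch count, and device are illustrative values; MLP is the user-defined Chain from that example):

        import chainer
        import chainer.links as L
        from chainer import training
        from chainer.training import extensions

        train, test = chainer.datasets.get_mnist()
        train_iter = chainer.iterators.SerialIterator(train, 100)
        test_iter = chainer.iterators.SerialIterator(test, 100, repeat=False, shuffle=False)

        model = L.Classifier(MLP(1000, 10))      # wrap the user-defined MLP chain
        optimizer = chainer.optimizers.Adam()
        optimizer.setup(model)

        # Iterator + Optimizer -> Updater -> Trainer, then run the whole training loop
        updater = training.StandardUpdater(train_iter, optimizer, device=-1)
        trainer = training.Trainer(updater, (20, 'epoch'), out='result')
        trainer.extend(extensions.Evaluator(test_iter, model, device=-1))
        trainer.run()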
  7. Overall picture of the Trainer (train_mnist.py)
     (Diagram: the Trainer holds an Updater (StandardUpdater) and Extensions such as dump_graph, snapshot, LogReport, PrintReport, and ProgressBar. The Updater holds an Iterator over the train dataset and an Optimizer wrapping a Classifier around the model (MLP), together with a converter, loss_func, and device. The Evaluator holds an Iterator over the test dataset, together with a converter and device.)
  8. Trainer
     • The core of the Trainer framework.
     • Runs the Updater it is given, and the Evaluator if one is needed.
     • Graph dumping, snapshots, reporting, progress display, and so on can be run as Extensions (a registration sketch follows).
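     For example, the extensions listed in the overall diagram are attached to the Trainer roughly like this (a sketch; each extension fires on its trigger, once per epoch by default):

        trainer.extend(extensions.dump_graph('main/loss'))            # dump the computational graph
        trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))   # snapshot the trainer state
        trainer.extend(extensions.LogReport())                        # aggregate reported values
        trainer.extend(extensions.PrintReport(
            ['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy']))
        trainer.extend(extensions.ProgressBar())                      # show training progress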
  9. Trainer
     • Keeps calling the Updater's update() until the specified number of epochs is reached.

        # examples/mnist/train_mnist.py
        trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

        # chainer/training/trainer.py
        class Trainer(object):
            def run(self):
                update = self.updater.update
                # main training loop
                try:
                    while not stop_trigger(self):
                        self.observation = {}
                        with reporter.scope(self.observation):
                            update()
  10. Updater (highlighted in the overall Trainer diagram from slide 7)
  11. Updater
     • Processes the input step by step.
     • Holds the input Iterator and the Optimizer.
     • Reads one batch at a time from the Iterator, converts it, and feeds it to the Optimizer.
  12.–15. Updater (code walkthrough)

        # examples/mnist/train_mnist.py
        updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)

        # chainer/training/updater.py
        class StandardUpdater(Updater):
            def update_core(self):
                batch = self._iterators['main'].next()          # fetch one batch from the Iterator
                in_arrays = self.converter(batch, self.device)  # run the converter (convert / to_gpu)
                optimizer = self._optimizers['main']
                loss_func = self.loss_func or optimizer.target
                if isinstance(in_arrays, tuple):
                    in_vars = tuple(variable.Variable(x) for x in in_arrays)
                    optimizer.update(loss_func, *in_vars)       # call the Optimizer's update

     Slides 13–15 highlight, in turn: fetching one batch from the Iterator, running it through the converter (converting it and moving it with to_gpu), and calling the Optimizer's update().
  16. Iterator (highlighted in the overall Trainer diagram from slide 7)
  17. Iterator
     • As an iterator, returns batches.
     • Keeps track of how many passes over the data have been made.

        # chainer/iterators/serial_iterator.py
        class SerialIterator(iterator.Iterator):
            def __next__(self):
                ...
                return batch

            @property
            def epoch_detail(self):
                return self.epoch + self.current_position / len(self.dataset)

        # examples/mnist/train_mnist.py
        train, test = chainer.datasets.get_mnist()
        train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
  18. Optimizer (highlighted in the overall Trainer diagram from slide 7)
  19. Optimizer
     • Forwards the input data through the model and backpropagates the returned loss.
     • There is an implementation per optimization algorithm: SGD, MomentumSGD, Adam, ...
     • All of them are abstracted behind Optimizer.
  20.–23. Optimizer (code walkthrough)

        # chainer/training/updater.py
        loss_func = self.loss_func or optimizer.target   # target is the model (Classifier) passed to setup()
        optimizer.update(loss_func, *in_vars)

        # examples/mnist/train_mnist.py
        optimizer = chainer.optimizers.Adam()
        optimizer.setup(model)

        # chainer/optimizer.py
        class GradientMethod(Optimizer):
            def update(self, lossfun=None, *args, **kwds):
                if lossfun is not None:
                    use_cleargrads = getattr(self, '_use_cleargrads', False)
                    loss = lossfun(*args, **kwds)        # forward through the model
                    if use_cleargrads:
                        self.target.cleargrads()
                    else:
                        self.target.zerograds()
                    loss.backward()                      # run backward

     Slides 21–23 highlight, in turn: that target is the model (Classifier) it was given, the forward pass through the model, and the backward pass.
  24. Classifier (highlighted in the overall Trainer diagram from slide 7)
  25. Classifier
     • A wrapper around a model for supervised learning.
     • Computes the loss and accuracy from the input and the ground-truth data.
  26.–30. Classifier (code walkthrough)

        # examples/mnist/train_mnist.py
        model = L.Classifier(MLP(args.unit, 10))

        # chainer/links/model/classifier.py
        class Classifier(link.Chain):
            def __init__(self, predictor,
                         lossfun=softmax_cross_entropy.softmax_cross_entropy,  # loss function to use
                         accfun=accuracy.accuracy):
                ...

            def __call__(self, *args):
                self.y = self.predictor(*x)                 # forward through the model
                self.loss = self.lossfun(self.y, t)         # compute the loss
                reporter.report({'loss': self.loss}, self)
                if self.compute_accuracy:
                    self.accuracy = self.accfun(self.y, t)  # compute the accuracy
                    reporter.report({'accuracy': self.accuracy}, self)
                return self.loss                            # return the loss

     Slides 27–30 highlight, in turn: the loss function being specified, the forward pass through the model, the loss and accuracy computation, and the loss being returned.
  31. Evaluator (highlighted in the overall Trainer diagram from slide 7)
  32. Evaluator
     • Computes loss, accuracy, and so on against the test data for validation.
     • Each epoch, it validates the model as trained so far.
     • Broadly, it mirrors the Updater.
  33.–36. Evaluator (code walkthrough)

        # examples/mnist/train_mnist.py
        test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                     repeat=False, shuffle=False)
        trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

        # chainer/training/extensions/evaluator.py
        class Evaluator(extension.Extension):
            def evaluate(self):
                iterator = self._iterators['main']
                target = self._targets['main']
                eval_func = self.eval_func or target
                it = copy.copy(iterator)
                for batch in it:                                        # read everything from the Iterator
                    observation = {}
                    with reporter_module.report_scope(observation):
                        in_arrays = self.converter(batch, self.device)  # run the converter (convert / to_gpu)
                        if isinstance(in_arrays, tuple):
                            in_vars = tuple(variable.Variable(x, volatile='on')
                                            for x in in_arrays)
                            eval_func(*in_vars)                         # forward through the model

     Slides 34–36 highlight, in turn: reading everything from the Iterator, running it through the converter (converting it and moving it with to_gpu), and the forward pass through the model.
  37. Things not covered
     • The Reporter machinery (a rough sketch follows).
     • Why the Evaluator calls eval_func but never uses the return value:
       • roughly speaking, the Classifier registers loss and accuracy with the reporter internally.
     • Extensions in detail.
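     Roughly, the Reporter mechanism works like this: a link reports named values during its forward pass, the Trainer collects them into the observation dictionary of the current iteration, and LogReport / PrintReport pick them up from there. A minimal sketch (MyModel and its predictor are placeholders, not code from the example):

        import chainer
        import chainer.functions as F
        from chainer import reporter

        class MyModel(chainer.Chain):
            def __call__(self, x, t):
                y = self.predictor(x)
                loss = F.softmax_cross_entropy(y, t)
                # register 'loss' under this link's name scope; this is exactly
                # what Classifier does with loss and accuracy
                reporter.report({'loss': loss}, self)
                return loss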
  38. Part 2: Wrestling with NStepLSTM
  39. What is NStepLSTM?
     • A Link for RNNs, introduced in Chainer 1.16.0.
     • Runs fast, benefiting from cuDNN.
     • It is used differently from the existing LSTM.
     • The sample for the existing LSTM is examples/ptb/train_ptb.py.
  40. RNN
     • Recurrent Neural Network.
     • Used for "sequence data", where the ordering carries meaning.
     • Applications: sentence prediction, speech recognition, forecasting fluctuating values.
     • Example: given the beginning of a sentence, predict the next word.
     (Diagram: the words 私 / は / 白い / 犬 / が feed the inputs x1 through x5, each step producing y1 through y5; taking x1 through x5 as the input data, estimate y5, the next word "?".)
  41.–42. LSTM and NStepLSTM
     • Example data: data 1 = [1, 2] with labels [A, B]; data 2 = [1, 2, 3] with labels [A, B, C].
     • LSTM (fed one step at a time):
       • x1: Variable[1, 1], t1: Variable[B, B]
       • x2: Variable[2, 2], t2: Variable[0, C]
       • when the lengths do not match, the shorter sequence has to be padded with 0 or similar.
     • NStepLSTM (fed all at once):
       • xs: [Variable[1,2], Variable[1,2,3]]
       • ts: [Variable[A,B], Variable[A,B,C]]
       • no need to align the lengths; you pass a list of Variables (an input-format sketch follows).
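     In code, the difference in input format looks roughly like this (a sketch with random float features of size n_units; in the PTB example later on, integer word IDs go through EmbedID first instead):

        import numpy as np
        import chainer

        n_units = 4   # illustrative feature size

        # LSTM, fed step by step: one Variable per time step covering the whole
        # mini-batch, so every sequence must share the same length (pad with 0 etc.).
        x1 = chainer.Variable(np.zeros((2, n_units), dtype=np.float32))  # step 1 for both sequences

        # NStepLSTM, fed all at once: a plain Python list of Variables, one per
        # sequence, each with its own length, so no padding is needed.
        xs = [chainer.Variable(np.random.rand(2, n_units).astype(np.float32)),  # length-2 sequence
              chainer.Variable(np.random.rand(3, n_units).astype(np.float32))]  # length-3 sequence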
  43.–44. NStepLSTM sample

        class RNNNStepLSTM(chainer.Chain):
            def __init__(self, n_layer, n_units, train=True):
                super(RNNNStepLSTM, self).__init__(
                    # specify the number of layers, the number of units, and the dropout ratio
                    l1=L.NStepLSTM(n_layer, n_units, n_units, 0.5, True),
                )
                self.n_layer = n_layer
                self.n_units = n_units
                self.train = train  # keep the train flag used in __call__

            def __call__(self, xs):
                xp = self.xp
                # create the initial hidden and cell states and pass them in
                hx = chainer.Variable(xp.zeros(
                    (self.n_layer, len(xs), self.n_units), dtype=xp.float32))
                cx = chainer.Variable(xp.zeros(
                    (self.n_layer, len(xs), self.n_units), dtype=xp.float32))
                # the input xs is a list of Variables; the output ys is a list of Variables too
                hy, cy, ys = self.l1(hx, cx, xs, train=self.train)
  45. Where NStepLSTM clashes with the standard Trainer setup
     • In the standard Trainer configuration, the model is handed Variables.
     • With NStepLSTM, the model has to be handed a "list of Variables" instead (a converter sketch follows).

        # chainer/training/updater.py
        class StandardUpdater(Updater):
            def update_core(self):
                batch = self._iterators['main'].next()
                in_arrays = self.converter(batch, self.device)
                ...
                if isinstance(in_arrays, tuple):
                    # a Variable is created for each array and passed straight to the model
                    in_vars = tuple(variable.Variable(x) for x in in_arrays)
                    optimizer.update(loss_func, *in_vars)
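     One way to bridge this is a converter that keeps each sequence separate instead of stacking the mini-batch into a single array. This is a hypothetical sketch (seq_convert is a name introduced here, not part of Chainer); note that StandardUpdater would still wrap each element of the returned tuple in a single Variable, which is why the reimplementation described on the following slides also replaces the Updater so that the lists reach the model untouched:

        from chainer import cuda

        def seq_convert(batch, device=None):
            # batch is a list of (sequence, labels) pairs of varying lengths
            if device is None or device < 0:
                xs = [x for x, _ in batch]
                ts = [t for _, t in batch]
            else:
                xs = [cuda.to_gpu(x, device) for x, _ in batch]
                ts = [cuda.to_gpu(t, device) for _, t in batch]
            return xs, ts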
  46. Where NStepLSTM clashes with the standard Trainer setup
     • The loss computation can no longer be delegated to the existing Classifier (a lossfun sketch follows).

        # chainer/links/model/classifier.py
        class Classifier(link.Chain):
            def __init__(self, predictor,
                         lossfun=softmax_cross_entropy.softmax_cross_entropy,
                         accfun=accuracy.accuracy):
                ...

            def __call__(self, *args):
                self.y = self.predictor(*x)
                # with NStepLSTM output, y is a list of Variables, so lossfun
                # cannot be applied directly; the same goes for accfun
                self.loss = self.lossfun(self.y, t)
                reporter.report({'loss': self.loss}, self)
                if self.compute_accuracy:
                    self.accuracy = self.accfun(self.y, t)
                    reporter.report({'accuracy': self.accuracy}, self)
                return self.loss
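     A hypothetical lossfun for list-shaped outputs, following the per-sequence approach used later in the PTB reimplementation (sequence_loss is a name introduced here for illustration):

        import chainer.functions as F

        def sequence_loss(ys, ts):
            # ys and ts are lists of Variables, one pair per sequence:
            # apply softmax_cross_entropy to each sequence and sum the losses
            loss = 0
            for y, t in zip(ys, ts):
                loss += F.softmax_cross_entropy(y, t)
            return loss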
  47. ptb/train_ptb.py with NStepLSTM
     • https://github.com/kei-s/chainer-ptb-nsteplstm/blob/master/train_ptb_nstep.py
     • Reimplements train_ptb.py with NStepLSTM, keeping the structure (on top of Trainer) as close to the original as possible.
     • Disclaimer:
       • it ran to completion in test mode (100 samples), but I have not run it over the full data;
       • there is probably some wasted code...
  48. ptb/train_ptb.py with NStepLSTM
     • Model
       • To run EmbedID, concat the sequences once and split them again with split_axis.
       • Apply Linear to each element of the sequences.
     • Iterator
       • Pass bprop_len to the Iterator and have it return the batch organized into sequences.
     • Converter
       • Follows the NStepLSTM-based seq2seq example: https://github.com/pfnet/chainer/pull/2070
     • Updater
       • Changed to pass the sequences to the model as they are.
     • Evaluator
       • Changed to pass the sequences to the model as they are.
     • Lossfun
       • Apply softmax_cross_entropy to each sequence and sum the results.
  49. ptb/train_ptb.py with NStepLSTM
     • Did it get faster?
       • The estimated time went from 6 hours to 3 hours.
       • That said, an actual run would probably take more than 3 hours.
       • The estimated time does not seem to account for the Evaluator part.
     • What speeding things up requires:
       • keep the processing inside the Chainer / NumPy world;
       • looping in the Python world is considerably slower;
       • the lossfun currently loops, but it could probably concat everything and make a single call (though compatibility feels uncertain); a concat-based sketch follows.
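     An untested sketch of that concat idea (an illustration of the last bullet, not what the repository actually does):

        import chainer.functions as F

        def sequence_loss_concat(ys, ts):
            # concatenate every per-sequence output and target along the first
            # axis and call softmax_cross_entropy once, avoiding the Python loop
            return F.softmax_cross_entropy(F.concat(ys, axis=0), F.concat(ts, axis=0))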
  50. © 2017 Retrieva, Inc.
