[PR12] PR-036 Learning to Remember Rare Events

Paper review for "Learning to Remember Rare Events (ICLR2017)"



1. Learning to Remember Rare Events
    Paper appeared at ICLR 2017: https://arxiv.org/abs/1703.03129
    Authors: Łukasz Kaiser, Ofir Nachum, Aurko Roy, and Samy Bengio (Google Brain)
    Reviewed by Taegyun Jeon
2. What can we learn from this paper?
    1. A memory-augmented deep neural network
    2. Two tasks:
       One-shot learning (Omniglot dataset)
       Life-long one-shot learning (large-scale machine translation)
    3. A TensorFlow implementation for one-shot learning
       Official code from Google Brain using TensorFlow
3. Problem Definition
    Rare events vs. the average case
    Image from "One Shot Learning" (Jisung Kim @ TensorFlow-KR 2nd Meetup Lightning Talk)
4. 8 Tactics To Combat Imbalanced Training Data
    1. Collect More Data
    2. Try Changing Your Performance Metric
    3. Try Resampling Your Dataset
    4. Try Generate Synthetic Samples
    5. Try Different Algorithms
    6. Try Penalized Models
    7. Try a Different Perspective
    8. Try Getting Creative
    From "8 Tactics to Combat Imbalanced Classes in Your Machine Learning Dataset" @ Machine Learning Mastery
5. Problem Definition (for rare events)
    Deep Neural Networks:
       Extend the training data
       Re-train them to handle such rare or new events
       Very SLOW!! (gradient-based optimization)
    Humans (life-long fashion):
       Learn from a single example
6. Key Concepts
    Deep Neural Networks (+ Memory Module)
7. Previous Works
    Meta-Learning with Memory-Augmented Neural Networks
       Idea: write the pair of "image and label" into the memory
    Matching Networks for One Shot Learning
       Idea: train a fully end-to-end nearest-neighbor classifier
    Note from A. Karpathy
8. Memory module
    Define a memory of size memory-size as a triple:
       M = (K_{m×k}, V_m, A_m)
    where m = memory-size and k = key-size.
    Key K: activations of a chosen layer of a neural network.
    Value V: ground-truth targets for the given example.
    Age A: tracks the age of each item stored in memory.
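A minimal sketch of this triple as NumPy arrays (the names memory-size and key-size come from the slide; the concrete sizes below are the defaults listed later in the Quick Start and parameter slides, and the zero initialization is an assumption, not the official code):

    import numpy as np

    memory_size, key_size = 8192, 128   # m and k

    # M = (K, V, A): keys, values, and ages
    K = np.zeros((memory_size, key_size), dtype=np.float32)  # K is m x k
    V = np.zeros(memory_size, dtype=np.int64)                # one ground-truth target per slot
    A = np.zeros(memory_size, dtype=np.int64)                # age of each stored item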
9. Memory module (query)
    A memory query q is a vector of size key-size:
       q ∈ R^k, ||q|| = 1
    The nearest neighbor(*) of q in M:
       NN(q, M) = argmax_i q · K[i]
    Given a query q, the memory M computes the k nearest neighbors:
       (n_1, ..., n_k) = NN_k(q, M)
    and returns, as the main result, the value V[n_1].
    (*) Since the keys are normalized, this is the nearest neighbor w.r.t. cosine similarity.
10. Memory module (query)
    Cosine similarities: d_i = q · K[n_i]
    Return softmax(d_1 · τ, ..., d_k · τ)
    Inverse of the softmax temperature: τ = 40
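A hedged NumPy sketch of the query described on slides 9 and 10, reusing the K and V arrays from the sketch above; the function and argument names are mine, not the repository's:

    import numpy as np

    def query_memory(q, K, V, top_k=256, tau=40.0):
        """Return the predicted value, the k nearest slots, and softmax weights over them."""
        q = q / np.linalg.norm(q)          # keys and query are L2-normalized,
        d = K @ q                          # so the dot product is the cosine similarity
        idx = np.argsort(-d)[:top_k]       # (n_1, ..., n_k) = NN_k(q, M)
        logits = tau * d[idx]              # tau is the inverse softmax temperature
        weights = np.exp(logits - logits.max())
        weights /= weights.sum()
        return V[idx[0]], idx, weights     # main result: the value V[n_1]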
11. [Note] Softmax temperature, τ
    The idea is to control the randomness of predictions:
       High temperature: softmax outputs are closer to each other
       Low temperature: softmax outputs become more and more "hardmax"
    For a low temperature (τ → 0+), the probability of the output with the highest expected reward tends to 1.
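A tiny illustration of that note (plain NumPy, for intuition only; multiplying by the inverse temperature τ = 40 from the previous slide is equivalent to dividing by a temperature of 1/40):

    import numpy as np

    def softmax(x):
        e = np.exp(x - np.max(x))
        return e / e.sum()

    d = np.array([0.9, 0.8, 0.1])   # raw cosine similarities
    print(softmax(d * 1.0))         # high temperature: outputs stay close to each other
    print(softmax(d * 40.0))        # low temperature (tau = 40): nearly "hardmax"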
12. Memory module (episode)
    Slide from "Meta-learning with memory-augmented neural networks" (Slideshare, H. Kim)
13. Memory module (train)
    Memory loss
       Query q and the correct desired (supervised) value v.
       Classification: v would be the class label.
       Seq2Seq: v would be the output token of the current time step.
14. Memory module (train)
    loss(q, v, M) = [q · K[n_b] − q · K[n_p] + α]_+
       K[n_p]: positive neighbor, V[n_p] = v
       K[n_b]: negative neighbor, V[n_b] ≠ v
       α: margin; the loss is zero once the positive similarity exceeds the negative one by at least α
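A minimal NumPy sketch of this margin loss, reusing K and V from the earlier sketches; the way the positive and negative neighbors are picked here is simplified compared to the paper, which selects them among the returned k nearest neighbors:

    import numpy as np

    def memory_loss(q, v, K, V, alpha=0.1):
        """Hinge loss: pull q toward a key with the correct value, push it from a wrong one."""
        d = K @ (q / np.linalg.norm(q))
        order = np.argsort(-d)                      # slots sorted by similarity to q
        n_p = next(i for i in order if V[i] == v)   # positive neighbor: V[n_p] == v
        n_b = next(i for i in order if V[i] != v)   # negative neighbor: V[n_b] != v
        return max(0.0, d[n_b] - d[n_p] + alpha)    # [q·K[n_b] - q·K[n_p] + alpha]_+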
15. Memory module (update)
    Case V[n_1] = v:
       K[n_1] ← (q + K[n_1]) / ||q + K[n_1]||
       A[n_1] ← 0
    Case V[n_1] ≠ v:
       if the memory has empty space at n_empty, assign n' = n_empty
       if not, n' = argmax_i A[i]
       K[n'] ← q, V[n'] ← v, and A[n'] ← 0
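And a sketch of that update rule, operating in place on the arrays from the first sketch (the random tie-breaking the paper adds when picking the oldest slot is omitted):

    import numpy as np

    def memory_update(q, v, n1, K, V, A):
        """Update memory after a query whose nearest neighbor was slot n1."""
        A += 1                                    # every stored item gets one step older
        if V[n1] == v:                            # the correct value is already stored:
            K[n1] = (q + K[n1]) / np.linalg.norm(q + K[n1])   # merge the keys
            A[n1] = 0
        else:                                     # wrong value: write q into a fresh slot
            n_new = int(np.argmax(A))             # an empty slot, or else the oldest one
            K[n_new], V[n_new], A[n_new] = q, v, 0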
16. Memory module (train & update)
17. Experiments (Evaluation)
    1. Evaluation on Omniglot dataset
    2. Evaluation on synthetic task
    3. Evaluation on English-German translation model
       Qualitative side: rarely-occurring words
       Quantitative side: BLEU score
18. Experiments (Omniglot Dataset)
19. Experiments (Omniglot Dataset)
    Omniglot dataset
       This dataset contains 1623 different handwritten characters from 50 different alphabets.
       Each of the 1623 characters was drawn online via Amazon's Mechanical Turk by 20 different people.
       Each image is paired with stroke data: sequences of [x, y, t] coordinates with time (t) in milliseconds.
       Stroke data is available in MATLAB files only.
       Omniglot dataset for one-shot learning (GitHub): https://github.com/brendenlake/omniglot
20. Experiments (Omniglot Dataset)
    CNN architecture:
       (Conv, ReLU), (Conv, ReLU), pool,
       (Conv, ReLU), (Conv, ReLU), pool, FC, FC
    Memory module
    Output layer (prediction)
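A rough tf.keras sketch of an embedder with that Conv/Conv/pool, Conv/Conv/pool, FC, FC shape; the filter counts, kernel sizes, and 28x28 input are my assumptions (the repository's LeNet class differs in detail), with the final layer producing the rep_dim-sized key for the memory:

    import tensorflow as tf

    def build_embedder(rep_dim=128):
        # (Conv, ReLU), (Conv, ReLU), pool, (Conv, ReLU), (Conv, ReLU), pool, FC, FC
        return tf.keras.Sequential([
            tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same',
                                   input_shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same'),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
            tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(rep_dim),  # key vector, L2-normalized before the memory query
        ])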
21. Experiments (Omniglot Dataset)
    way: different alphabets
    shot: different characters
22. Experiments (GNMT)
    Decoder path
       Key: result of attention a_t
       Combine the value and the LSTM output (at decoder time-step t)
23. Experiments (GNMT)
24. Experiments (GNMT)
    Convolutional Gated Recurrent Unit (CGRU)
    For more information: read the Lunit tech blog
25. Conclusions
    Long-term memory module
    Input is embedded with a simple CNN (LeNet)
    The returned k-NN result could be used by other layers.
26. Code Review (GitHub)
    1. data_utils.py: data loading and other utilities.
    2. train.py: script for training the model.
    3. memory.py: memory module for storing "nearest neighbors".
    4. model.py: model using the memory component.
27. Quick Start
    1) First download and set up the Omniglot data by running:
       python data_utils.py
    2) Then run the training script:
       python train.py --memory_size=8192 --batch_size=16 --validation_length=50 --episode_width=5 --episode_length=30
28. 3) The first validation batch may look like this (although it is noisy):
       0-shot: 0.040, 1-shot: 0.404, 2-shot: 0.516, 3-shot: 0.604, 4-shot: 0.656, 5-shot: 0.684
    4) At step 500 you may see something like this:
       0-shot: 0.036, 1-shot: 0.836, 2-shot: 0.900, 3-shot: 0.940, 4-shot: 0.944, 5-shot: 0.916
    5) At step 4000 you may see something like this:
       0-shot: 0.044, 1-shot: 0.960, 2-shot: 1.000, 3-shot: 0.988, 4-shot: 0.972, 5-shot: 0.992
29. 0) Basic parameters
    rep_dim: 128, dimension of keys to use in memory
    episode_length: 100, length of episode
    episode_width: 5, number of distinct labels in a single episode
    memory_size: None, number of slots in memory
    batch_size: 16, batch size
    num_episodes: 100000, number of training episodes
    validation_frequency: 20, assess validation accuracy every so many training episodes
    validation_length: 10, number of episodes used to compute validation accuracy
    seed: 888, random seed for training sampling
    save_dir: '', directory to save the model to
    use_lsh: False, use locality-sensitive hashing (NOTE: not fully tested)
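For reference, a sketch of how parameters like these are declared with TF1-style flags (the "Tips from train.py" slide below notes that they are passed via tf.flags); the names and defaults are copied from the list above, but this is not a verbatim excerpt of train.py:

    import tensorflow as tf  # TF1-style; tf.flags aliases tf.app.flags

    FLAGS = tf.flags.FLAGS
    tf.flags.DEFINE_integer('rep_dim', 128, 'dimension of keys to use in memory')
    tf.flags.DEFINE_integer('episode_length', 100, 'length of episode')
    tf.flags.DEFINE_integer('episode_width', 5, 'number of distinct labels in a single episode')
    tf.flags.DEFINE_integer('memory_size', None, 'number of slots in memory')
    tf.flags.DEFINE_integer('batch_size', 16, 'batch size')
    tf.flags.DEFINE_bool('use_lsh', False, 'use locality-sensitive hashing (not fully tested)')
    # ... remaining flags follow the same pattern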
30. 1) data_utils.py
    def preprocess_omniglot():
        # Download and prepare raw Omniglot data.
    def maybe_download_data():
        # Download Omniglot repo if it does not exist.
    def write_datafiles():
        # Load and preprocess images from a directory and
        # write them to a file.
    def crawl_directory():
        # Crawls data directory and returns stuff.
    def resize_images():
        # Resize images to new dimensions.
31. 1) Tips from data_utils.py
    Messages are managed with logging; the log level is adjustable.
    Data is dumped and reused with pickle. (What about TFRecord and queues?)
    Simple external commands are run with subprocess.
    Only the train dataset is augmented (rotations of 0, 90, 180, 270 degrees).
    Resizing is performed (original: 105, converted: 28).
    OUTPUT: train_omni.pkl (733M), test_omni.pkl (126M)
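A small sketch of the logging-plus-pickle pattern described above (the output file names come from the slide; the function itself is illustrative, not the actual data_utils.py):

    import logging
    import pickle

    logging.basicConfig(level=logging.INFO)  # log level is adjustable

    def dump_datasets(train_data, test_data):
        # Dump the preprocessed splits so later runs can skip preprocessing.
        for name, data in [('train_omni.pkl', train_data), ('test_omni.pkl', test_data)]:
            with open(name, 'wb') as f:
                pickle.dump(data, f)
            logging.info('wrote %s', name)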
32. 2) train.py
    def data_utils.get_data():
        # Get data in form suitable for episodic training.
        # Returns: Train and test data as dictionaries mapping
        # label to list of examples.
    class Trainer():
        def run():
            self.sample_episode_batch()
            outputs = self.model.episode_step()
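A hedged sketch of what sampling one episode means here: pick episode_width distinct labels, then draw episode_length labeled examples from them (plain Python; the real Trainer.sample_episode_batch() is batched and more involved):

    import random

    def sample_episode(data, episode_length=30, episode_width=5):
        """data: dict mapping label -> list of examples, as returned by get_data()."""
        labels = random.sample(list(data.keys()), episode_width)
        episode = []
        for _ in range(episode_length):
            label = random.choice(labels)
            episode.append((random.choice(data[label]), label))
        return episode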
33. 2) Tips from train.py
    Basic parameters are passed via tf.flags.
    Training-related messages are reported with logging.
    assert is used to catch episode-length errors.
    Train and validation run together (at a 20:1 ratio).
34. 3) model.py
    class LeNet(object):
        # Standard CNN architecture
    class Model(object):
        # Model for coordinating between CNN embedder and
        # Memory module.
35. 3) model.py
    Lines 152-158, core_builder():
        embeddings = self.embedder.core_builder(x)
        if keep_prob < 1.0:
          embeddings = tf.nn.dropout(embeddings, keep_prob)
        memory_val, _, teacher_loss = self.memory.query(
            embeddings, y, use_recent_idx=use_recent_idx)
        loss, y_pred = self.classifier.core_builder(
            memory_val, x, y)
        return loss + teacher_loss, y_pred
36. 3) Tips from model.py
    core_builder(): adds the memory to an existing network.
    For an input image x, an embedding vector is produced with LeNet.
    weight and bias are created up front with tf.get_variable.
    Each piece of the model's functionality is factored out as finely as possible.
37. 4) memory.py
    class Memory(object):
        def get_hint_pool_idxs(...):
            # Get small set of idxs to compute nearest neighbor
            # queries on.
        def query(...):
            # Queries memory for nearest neighbor.
    class LSHMemory(Memory):
        # Memory employing locality sensitive hashing.
        # Note: Not fully tested.
38. 4) Tips from memory.py
    Either Memory or LSHMemory can be selected; Memory is recommended.
    The memory behavior from the paper is implemented in an intuitive way.
    By changing only memory_size and key_size, it can be attached to almost any network.
39. Appendix (Reviews)
    1. Lunit Tech Blog (by Hyo-Eun Kim) (Link)
    2. OpenReview (ICLR 2017) (Link)
    3. BAIR Blog: "Learning to Learn" (by Chelsea Finn) (Link)
    4. Learning to Remember Rare Events (by Hongbae Kim) (Slideshare)
    5. One Shot Learning (by Jisung Kim) (Slideshare)
40. Appendix (Implementations)
    1. TensorFlow/models (Google Brain) (GitHub)
