7. 学習データ準備
def pickup_files(git_root, extensions):
    """Yield the path of each file under ``git_root`` matching ``extensions``.

    The directory tree rooted at ``git_root`` is walked with ``os.walk`` and
    every file whose name ends with at least one of the given suffixes is
    yielded exactly once, as ``os.path.join(dirpath, filename)``.

    Args:
        git_root: Directory to search recursively.
        extensions: Iterable of filename suffixes (e.g. ``['.py', '.rb']``).

    Yields:
        Paths of the matching files, in ``os.walk`` order.
    """
    # str.endswith accepts a tuple of suffixes, so a file that matches
    # several suffixes is still yielded only once (the previous per-suffix
    # loop yielded it once per matching suffix).
    suffixes = tuple(extensions)
    for dirpath, _dirnames, filenames in os.walk(git_root):
        for fn in filenames:
            if fn.endswith(suffixes):
                yield os.path.join(dirpath, fn)
def writeout(git_root, extensions, out_file):
    """Concatenate every matching file under ``git_root`` into ``out_file``.

    Each file found by ``pickup_files(git_root, extensions)`` is read in full
    and appended to ``out_file`` followed by three newlines as a separator.
    A progress line is printed for every file that is written.

    Args:
        git_root: Directory to search recursively.
        extensions: Iterable of filename suffixes to include.
        out_file: Path of the file to create (or overwrite) with the result.
    """
    out_path = os.path.abspath(out_file)
    # Open the destination once; previously ``out`` was never defined and
    # the ``out_file`` argument was ignored.
    with open(out_file, 'w') as out:
        for filename in pickup_files(git_root, extensions):
            # The output file may itself live under git_root and match an
            # extension; never feed it back into itself.
            if os.path.abspath(filename) == out_path:
                continue
            with open(filename, 'r') as f:
                out.write(f.read() + '\n\n\n')
            print('{} read and written out.'.format(filename))
出力ファイルは10MBほどになります。
10. 学習に使うメインループ
# Main training loop: for every epoch, decay the learning rate, then run one
# optimisation step per batch while carrying the RNN state between batches.
# NOTE(review): indentation was lost in this copy; nesting follows the
# with / for / for structure implied by the statements below.
# Open a session (the execution manager).
with tf.Session() as sess:
# Run the op that initialises all variables before any training step.
tf.initialize_all_variables().run()
for e in xrange(args.num_epochs):
# Set the model's learning rate to learning_rate * decay_rate ** epoch.
sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
# presumably rewinds the loader to the first batch -- confirm in data_loader
data_loader.reset_batch_pointer()
# Start each epoch from the model's evaluated initial state.
state = model.initial_state.eval()
for b in xrange(data_loader.num_batches):
# Fetch the next batch of data.
x, y = data_loader.next_batch()
# Feed the inputs, the targets and the state returned by the previous step.
feed = {model.input_data: x, model.targets: y, model.initial_state: state}
# Run the predefined operations: cost, final state and the training op.
# The returned final state becomes the initial state of the next batch.
train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)