Speaker: Yi Wei
Title: Write debuggable Tensorflow code and find bugs in the herd
Abstract: Tensorflow is powerful but difficult to use. What usually happens is that after you have wired up a Tensorflow graph, it is difficult to verify its correctness, and when a problem occurs, it is difficult to debug. While working on various deep reinforcement learning algorithms at Prowler, I developed a set of techniques to mitigate the pain of writing, understanding and testing Tensorflow code. These techniques enable us to produce machine learning modules much faster, find bugs with minimal effort and, most importantly, deliver learners whose correctness is validated against their definitions, mainly from the corresponding research papers.
Bio: Yi Wei is a senior machine learning engineer at Prowler. He focuses on deep reinforcement learning algorithms for automated trading. Prior to Prowler, he co-founded CTX, a fintech company that provides algo-trading infrastructure. He also worked at Microsoft Research Cambridge for three years, where he developed the CodeSnippets technology that synthesizes code from users' natural language queries and from publicly available code repositories. The Bing search engine productized this technology in its tech search section and reported a 4% improvement in session success rate, one of its core metrics, a hard-to-achieve gain in the world of commercial search engines. He won the Microsoft Research Technology Transfer Award for the CodeSnippets project. Yi Wei received his PhD from ETH Zurich in 2012 on the topic of automated testing and bug fixing.
Debug Tensorflow Code with Specifications and Assertions
1. Find bugs in the herd with
debuggable Tensorflow code
Yi Wei
yiwei@prowler.io
2. Tensorflow code is difficult to debug and verify
● Tensor values are multi-dimensional arrays
● TensorBoard graph visualization has hundreds of nodes and edges
● Many (or not enough) tips on the Internet, but those tips never tell me
whether my code is correct
3. Specification to the rescue
● Reasoning about Tensorflow code is difficult because the debugger does
not know what is correct.
● Specification defines correctness of the code.
● Correctness w.r.t. the algorithm definition, not whether the code can
learn a model
Ask not what the debugger can do for you,
ask what you can do for the debugger
4. Three assertion techniques to verify correctness
● Tensor shape assertions to validate data shapes
● Tensor dependency assertions to validate graph structure
● Tensor equation assertions to validate numerical calculations
5. Technique 1: shape assertions
Write an assert to check the shape of every tensor you introduce.
prediction_tensor = q_function.output_tensor
assert prediction_tensor.shape.as_list() == [batch_size, action_dimension]
target_tensor = reward_tensor + discount * bootstrapped_tensor
assert target_tensor.shape.as_list() == [batch_size, action_dimension]
loss_tensor = tf.losses.mean_squared_error(target_tensor, prediction_tensor)
assert loss_tensor.shape.as_list() == []
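The snippet above references project-specific objects (`q_function`, `reward_tensor`). As a self-contained sketch of the same pattern, here is the assertion discipline with NumPy arrays standing in for the tensors; the sizes `batch_size` and `action_dimension` are made-up illustrative values:

```python
import numpy as np

# Hypothetical sizes, standing in for the learner's real configuration.
batch_size, action_dimension = 32, 4

# Stand-ins for the tensors on the slide.
prediction = np.zeros((batch_size, action_dimension))
reward = np.ones((batch_size, action_dimension))
bootstrapped = np.ones((batch_size, action_dimension))
discount = 0.99

target = reward + discount * bootstrapped
# Assert the shape of every value you introduce, exactly as on the slide.
assert list(prediction.shape) == [batch_size, action_dimension]
assert list(target.shape) == [batch_size, action_dimension]

loss = np.mean((target - prediction) ** 2)
# The mean-squared-error loss collapses to a scalar: an empty shape.
assert list(np.shape(loss)) == []
```

A wrong broadcast (e.g. a reward of shape `[batch_size]` silently broadcast against `[batch_size, action_dimension]`) is exactly the kind of bug these assertions catch at graph-construction time.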
6. Next we need to validate graph structure, but how?
TensorBoard produces a complicated visualization that is not practical for most of us
7. Technique 2: Tensor group dependency
We developed a Python package, TensorGroupDependency, which:
● Visualizes the part of the graph involving only your tensors
● Helps you check tensor dependency correctness
● Automatically generates tensor graph structural assertions
● This is the key step that makes the whole process practical!
8. Use of TensorGroupDependency
d = TensorGroupDependency()
d.add(q_function, 'q_function')
d.add(q_function.output_tensor, 'q_value_tensor')
d.add(prediction_tensor, 'prediction_tensor')
d.add(target_tensor, 'target_tensor')
d.add(loss_tensor, 'loss_tensor')
dot = d.generate_dot_representation()
print(dot)
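TensorGroupDependency itself is not yet public, so as a rough illustration of the idea, here is a minimal, generic sketch: a checker that records declared dependencies between named tensors, answers transitive reachability queries, and auto-generates one structural assertion per edge. All names and the `DependencyChecker` API are hypothetical stand-ins, not the Prowler package:

```python
class DependencyChecker:
    """Toy stand-in for TensorGroupDependency over named nodes."""

    def __init__(self):
        self.deps = {}  # name -> set of names it directly depends on

    def add(self, name, depends_on=()):
        self.deps[name] = set(depends_on)

    def depends(self, name, ancestor):
        """True if `name` transitively depends on `ancestor`."""
        stack, seen = list(self.deps.get(name, ())), set()
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            if node == ancestor:
                return True
            stack.extend(self.deps.get(node, ()))
        return False

    def generate_assertions(self):
        """Auto-generate one structural assertion per declared edge."""
        lines = []
        for name, parents in sorted(self.deps.items()):
            for p in sorted(parents):
                lines.append(f"assert d.depends('{name}', '{p}')")
        return "\n".join(lines)


d = DependencyChecker()
d.add('q_value_tensor', depends_on=['q_function'])
d.add('prediction_tensor', depends_on=['q_value_tensor'])
d.add('target_tensor')
d.add('loss_tensor', depends_on=['prediction_tensor', 'target_tensor'])

# Transitive edge: the loss reaches the q-function through the prediction.
assert d.depends('loss_tensor', 'q_function')
# Missing edge: the target must NOT depend on the q-function.
assert not d.depends('target_tensor', 'q_function')
```

The second assertion shows why dependency checks matter: a target that accidentally backpropagates into the q-function is a classic, silent RL bug.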
9. Visualization from TensorGroupDependency
● Tensors as nodes
● Dependencies between tensors as edges
● You must be able to explain why each edge exists
● Assertions are generated automatically
13. Visualization for the StarCraft2 learner
Graphs become smaller and smaller because of composability.
14. Open source TensorGroupDependency
● TensorGroupDependency is the key component that
makes the whole process practical
● We are preparing to open-source
TensorGroupDependency
● Drop me an email at yiwei@prowler.io if you are interested
15. Technique 3: tensor equations
Write an assertion to check every equation in your algorithm.
_, prediction, target, loss = sess.run(
    [parameter_update_operations, prediction_tensor, target_tensor, loss_tensor],
    feed_dict={})
mean_square_error = np.mean(np.power(target - prediction, 2))
np.testing.assert_almost_equal(loss, mean_square_error, decimal=1)
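The snippet above depends on a live session; the recomputation pattern itself can be shown self-contained. Here NumPy arrays stand in for the values fetched via `sess.run`, and the loss value simulates the output of a correct graph (in real use it would come from `loss_tensor`):

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, action_dimension = 32, 4

# Stand-ins for the values fetched from the session.
prediction = rng.normal(size=(batch_size, action_dimension))
target = rng.normal(size=(batch_size, action_dimension))
# In the real code this number comes out of the graph; here we
# simulate a graph whose loss op is implemented correctly.
loss = np.mean((target - prediction) ** 2)

# Re-derive the loss from its mathematical definition, independently
# of the graph, and assert the two agree.
mean_square_error = np.mean(np.power(target - prediction, 2))
np.testing.assert_almost_equal(loss, mean_square_error, decimal=5)
```

The point of the technique is that the NumPy side is written straight from the paper's equation, so any divergence pinpoints a graph-construction bug rather than a data problem.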
17. Reasons for the bug-detecting effectiveness
● Practical specification writing
○ Specification defines correctness, no way around it.
○ Explicitly writing down specification helps you and the debugger
○ Practical is the keyword
● Fault localization
○ Locating where a fault originates in Tensorflow code is difficult
○ At each stage of specification, you only need to focus on places within that stage.
● Clear to-do list style engineering process
○ Each assertion stage has finite steps bounded by the tensors you introduce, usually a dozen
○ You know exactly when the verification process ends -- when you’ve validated your code!
○ When people know the exact steps, they are a lot more efficient.
18. Assertions enable advanced testing
● Ingredients of a test: input construction and an oracle
● Machine learning code takes numbers as input, so inputs are easy to construct
● Since the assertions already provide the oracle, generating tests is easy
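The combination of the two ingredients can be sketched as follows. `loss_under_test` is a hypothetical stand-in for the value the graph would produce, and `mse_oracle` is the equation-derived oracle from the previous slides; the shapes and seed are arbitrary:

```python
import numpy as np

def mse_oracle(target, prediction):
    """Reference implementation, written straight from the loss definition."""
    return np.mean(np.power(target - prediction, 2))

def loss_under_test(target, prediction):
    # Stand-in for the value the TF graph would produce for loss_tensor.
    return np.mean((target - prediction) ** 2)

# Because the inputs are just numbers, random generation is enough
# to construct many test cases automatically.
rng = np.random.default_rng(42)
for _ in range(100):
    shape = (rng.integers(1, 8), rng.integers(1, 8))
    target = rng.normal(size=shape)
    prediction = rng.normal(size=shape)
    np.testing.assert_almost_equal(
        loss_under_test(target, prediction),
        mse_oracle(target, prediction),
        decimal=6)
```

This is essentially property-based testing: random inputs plus an equation assertion as the oracle, with no hand-written expected values.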