Stochastic Computation Graphs, and
Pathwise Derivative Policy Gradient Methods
John Schulman
August 27, 2017
Stochastic Computation Graphs
Motivation
Gradient estimation is at the core of much of modern machine learning
Backprop is mechanical: this gives us more freedom to tinker with architectures
and loss functions
RL (and certain probabilistic models) involve components you can’t
backpropagate through—need REINFORCE-style estimators logprob *
advantage
Often we need a combination of backprop-style and REINFORCE-style
gradient estimation
RNN policies π(at | ht), where ht = f (ht−1, at−1, ot)
Instead of reward, we have a known differentiable function, e.g. classifier
with hard attention. Objective = logprob of correct label.
Each paper to propose such a model provides a new convoluted derivation
Gradients of Expectations
Want to compute ∇θ E[F]. Where's θ?
“Inside” the distribution, e.g., Ex∼p(· | θ) [F(x)]
∇θ Ex [f (x)] = Ex [f (x) ∇θ log px(x; θ)].
Score function (SF) estimator
Example: REINFORCE policy gradients, where x is the trajectory
Inside expectation: Ez∼N(0,1) [F(θ, z)]
∇θ Ez [f (x(z, θ))] = Ez [∇θ f (x(z, θ))].
Pathwise derivative (PD) estimator
Example: neural net with dropout
Often, we can reparametrize, to change from one form to another
Not a property of the problem: these are equivalent ways of writing down the same
stochastic process
What if F depends on θ in a complicated way, affecting both the distribution and F itself?
M. C. Fu. “Gradient estimation”. In: Handbooks in operations research and management science 13 (2006), pp. 575–616
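To make the two estimators concrete, here is a minimal numerical sketch (not from the slides): both estimate d/dθ E_{x∼N(θ,1)}[x²] = 2θ, with x reparameterized as x = θ + z; the sample size and seed are arbitrary choices.

```python
# Minimal sketch (not from the slides): both estimators target d/dtheta E_{x~N(theta,1)}[x^2] = 2*theta.
import numpy as np

rng = np.random.default_rng(0)
theta, n = 1.5, 200_000
z = rng.standard_normal(n)     # noise from a fixed distribution
x = theta + z                  # reparameterization: x ~ N(theta, 1)

# Score function (SF) estimator: E[f(x) * d/dtheta log p(x; theta)], where d/dtheta log N(x; theta, 1) = x - theta.
sf_estimate = np.mean(x**2 * (x - theta))

# Pathwise derivative (PD) estimator: E[d/dtheta f(x(z, theta))] = E[2 * (theta + z)].
pd_estimate = np.mean(2.0 * (theta + z))

print(sf_estimate, pd_estimate, 2 * theta)   # both are close to 3.0; the PD estimate has far lower variance
```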
Stochastic Computation Graphs
A stochastic computation graph is a DAG in which each node corresponds to a
deterministic or stochastic operation
Can automatically derive unbiased gradient estimators, with variance
reduction
Generalize backpropagation to deal with random variables
that we can’t differentiate through.
Why can’t we differentiate through them?
1) discrete random variables
2) unmodeled external world, in RL / control
Computation graphs vs. stochastic computation graphs: the latter add stochastic nodes.
Motivation / Applications: policy gradients in reinforcement learning; variational inference
(Paper excerpt shown on the slide: an importance-sampling identity and the inequality x ≥ log x + 1, reversed because the costs ĉ are negative, give
ES|θnew[ Σc∈C ĉ ] ≤ ES|θold[ Σc∈C ĉ + Σv∈S log( p(v | DEPSv, θnew) / p(v | DEPSv, θold) ) Q̂v ]
= ES|θold[ Σv∈S log p(v | DEPSv, θnew) Q̂v ] + const,
which permits majorization-minimization algorithms (like the EM algorithm); similar bounds have been derived by interpreting rewards (negative costs) as probabilities and taking a variational lower bound on log-probability. The excerpt continues into Appendix C.1 on the generalized EM algorithm and variational inference.)
it’s all about gradients of expectations!
Just differentiate:
∂/∂θ E[ Σ_{cost nodes c} c ]
= E[ Σ_{stochastic nodes x deterministically influenced by θ} (∂/∂θ log p(x | DEPSx)) · (Q̂x − baseline(x)) + Σ_{cost nodes c deterministically influenced by θ} ∂ĉ/∂θ ]
= ∂/∂θ (expected surrogate), where the surrogate adds log p(x | DEPSx) times the downstream cost (treated as a constant, minus an optional baseline) to the sum of costs.
Under certain conditions, the surrogate is a bound on the expected cost, related to variational lower bounds.
Equivalently, use backprop on the surrogate.
Overview
Why's it useful?
1) no need to rederive every time
2) enable generic software
Score function estimator: ∂/∂θ Ex [f (x)] = Ex [f (x) ∂/∂θ log p(x; θ)]
Pathwise derivative estimator: ∂/∂θ Ez [f (x(z, θ))] = Ez [∂/∂θ f (x(z, θ))]
(The rest of this slide is figure residue: the θ, φ, and baseline nodes, the surrogate-gradient formula repeated, and the per-node gradient terms for an example graph over nodes a, b, c, d, e; the same example is worked out on the Abstract Example slide below.)
J. Schulman, N. Heess, T. Weber, et al. “Gradient Estimation Using Stochastic Computation Graphs”. In: NIPS. 2015
Example: Hard Attention
Look at image (o1)
Determine crop window (a)
Get cropped image (o2)
Version 1
Output label (y)
Get reward of 1 for correct label
Version 2
Output parameter specifying distribution over label (κ)
Reward = prob or logprob of correct label (log p)
(Two graph sketches: Version 1 over nodes o1, a, o2, y, r with parameter θ; Version 2 over nodes o1, a, o2, κ, log p with parameter θ.)
Abstract Example
(Graph with parameters θ and φ, deterministic nodes a and d, stochastic node b, and cost nodes c and e.)
L = c + e. Want to compute d/dθ E[L] and d/dφ E[L].
Treat stochastic nodes (here, b) as constants, and introduce a loss
logprob ∗ (future cost − optional baseline) at each stochastic node
Obtain an unbiased gradient estimate by differentiating the surrogate:
Surrogate(θ, φ) = c + e              (1)
                + log p(b̂ | a, d) ĉ   (2)
(1): how parameters influence cost through deterministic dependencies
(2): how parameters affect distribution over random variables.
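A minimal sketch of this surrogate in PyTorch (an illustration, not the authors' code). The concrete node definitions below (a and d affine in θ and φ, b Gaussian given a and d, costs c = b² and e = d²) are assumptions chosen only so the example runs; the surrogate line itself mirrors the slide.

```python
# Sketch of the Abstract Example surrogate; the particular node definitions are assumptions.
import torch
from torch.distributions import Normal

theta = torch.tensor(0.5, requires_grad=True)
phi = torch.tensor(-0.3, requires_grad=True)

a = 2.0 * theta                 # deterministic node influenced by theta
d = phi + 1.0                   # deterministic node influenced by phi
p_b = Normal(a + d, 1.0)        # stochastic node b, conditioned on a and d
b = p_b.sample()                # sampled value b_hat; .sample() blocks gradient flow, as desired

c = b ** 2                      # cost downstream of the stochastic node
e = d ** 2                      # cost reached purely through deterministic dependencies

# Surrogate(theta, phi) = c + e + log p(b_hat | a, d) * c_hat, with c_hat treated as a constant.
surrogate = c + e + p_b.log_prob(b) * c.detach()
surrogate.backward()
print(theta.grad, phi.grad)     # single-sample unbiased estimates of dE[L]/dtheta and dE[L]/dphi
```

In this sketch, gradients reach θ only through the log-probability term (the sampled b̂ blocks the pathwise route), while φ receives contributions both from the deterministic cost e and from the log-probability term.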
Results
General method for calculating gradient estimator
Surrogate loss version: each stochastic node gives a loss expression
logprob(node value) * (sum of “downstream” costs treated as constants).
Add these to the sum of costs treated as variables.
Alternative generalized backprop: as usual, do reverse topological sort.
Instead of backpropagating using the Jacobian, backpropagate grad-logprob *
future costs.
Generalized version of baselines for variance reduction (without bias):
log p(node) * (sum of downstream costs treated as constants, minus
baseline(non-descendants of node))
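The variance-reduction baseline drops straight into the same surrogate. In the sketch above, the surrogate line would become the following (the particular baseline is illustrative; any function of non-descendants of b works):

```python
# Drop-in replacement for the surrogate line in the previous sketch (the baseline choice is illustrative).
# A baseline may depend only on non-descendants of the stochastic node (here a and d),
# so the added term has zero mean and leaves the gradient estimate unbiased.
baseline = ((a + d) ** 2 + 1.0).detach()      # e.g. an estimate of E[c | a, d]
surrogate = c + e + p_b.log_prob(b) * (c.detach() - baseline)
```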
Pathwise Derivative Policy Gradient Methods
Deriving the Policy Gradient, Reparameterized
Episodic MDP:
θ
s1 s2 . . . sT
a1 a2 . . . aT
RT
Want to compute ∇θ E[RT ]. We'll use ∇θ log π(at | st; θ)
Deriving the Policy Gradient, Reparameterized
Episodic MDP:
θ
s1 s2 . . . sT
a1 a2 . . . aT
RT
Want to compute ∇θ E[RT ]. We'll use ∇θ log π(at | st; θ)
Reparameterize: at = π(st, zt; θ). zt is noise from fixed distribution.
θ
s1 s2 . . . sT
a1 a2 . . . aT
z1 z2 . . . zT
RT
Deriving the Policy Gradient, Reparameterized
Episodic MDP:
θ
s1 s2 . . . sT
a1 a2 . . . aT
RT
Want to compute ∇θ E[RT ]. We'll use ∇θ log π(at | st; θ)
Reparameterize: at = π(st, zt; θ). zt is noise from fixed distribution.
θ
s1 s2 . . . sT
a1 a2 . . . aT
z1 z2 . . . zT
RT
Only works if P(s2 | s1, a1) is known
Using a Q-function
θ
s1 s2 . . . sT
a1 a2 . . . aT
z1 z2 . . . zT
RT
d/dθ E[RT ] = E[ Σ_{t=1}^{T} (dRT /dat) (dat/dθ) ]
= E[ Σ_{t=1}^{T} (d/dat E[RT | at]) (dat/dθ) ]
= E[ Σ_{t=1}^{T} (dQ(st, at)/dat) (dat/dθ) ]
= E[ Σ_{t=1}^{T} (d/dθ) Q(st, π(st, zt; θ)) ]
SVG(0) Algorithm
Learn Qφ to approximate Qπ,γ, and use it to compute gradient estimates.
N. Heess, G. Wayne, D. Silver, et al. “Learning continuous control policies by stochastic value gradients”. In: NIPS. 2015
SVG(0) Algorithm
Learn Qφ to approximate Qπ,γ, and use it to compute gradient estimates.
Pseudocode:
for iteration = 1, 2, . . . do
    Execute policy πθ to collect T timesteps of data
    Update πθ using g ∝ ∇θ Σ_{t=1}^{T} Q(st, π(st, zt; θ))
    Update Qφ using g ∝ ∇φ Σ_{t=1}^{T} (Qφ(st, at) − Q̂t)², e.g. with TD(λ)
end for
N. Heess, G. Wayne, D. Silver, et al. “Learning continuous control policies by stochastic value gradients”. In: NIPS. 2015
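A minimal PyTorch sketch of one such update (assumed interfaces, not the paper's implementation): pi(s, z) stands for a reparameterized policy network, Q for the critic, and q_target for the precomputed Q̂t (e.g. from TD(λ)); optimizers and batching are left to the caller.

```python
# Hypothetical single-batch SVG(0)-style update; the network/optimizer interfaces are assumptions.
import torch

def svg0_update(pi, Q, pi_opt, q_opt, s, a, z, q_target):
    # Critic step: fit Q_phi(s_t, a_t) toward Q_hat_t by squared-error regression.
    q_loss = ((Q(s, a) - q_target) ** 2).mean()
    q_opt.zero_grad(); q_loss.backward(); q_opt.step()

    # Policy step: ascend grad_theta of sum_t Q(s_t, pi(s_t, z_t; theta)).
    # Gradients also accumulate on Q's parameters here; they are cleared by the next q_opt.zero_grad().
    pi_loss = -Q(s, pi(s, z)).mean()
    pi_opt.zero_grad(); pi_loss.backward(); pi_opt.step()
    return q_loss.item(), pi_loss.item()
```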
SVG(1) Algorithm
θ
s1 s2 . . . sT
a1 a2 . . . aT
z1 z2 . . . zT
RT
Instead of learning Q, we learn
State-value function V ≈ V π,γ
Dynamics model f , approximating st+1 = f (st, at) + ζt
Given transition (st, at, st+1), infer ζt = st+1 − f (st, at)
Q(st, at) = E [rt + γV (st+1)] = E [rt + γV (f (st, at) + ζt)], and at = π(st, θ, ζt)
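A sketch of the resulting policy objective in PyTorch (assumed module names, not the paper's code): pi is a reparameterized policy, f a learned dynamics model, V a learned state-value function, and zeta_a the policy noise (in practice also inferred from the observed action). The paper treats the reward as a known differentiable function of state and action; here it is treated as a constant for brevity.

```python
# SVG(1)-style policy objective on one observed transition; interfaces and names are assumptions.
import torch

def svg1_policy_loss(pi, f, V, s_t, a_t, r_t, s_next, zeta_a, gamma=0.99):
    zeta_t = (s_next - f(s_t, a_t)).detach()   # infer dynamics noise from the observed transition
    a = pi(s_t, zeta_a)                        # re-evaluate the action so gradients reach theta
    q = r_t + gamma * V(f(s_t, a) + zeta_t)    # Q(s_t, a) = r_t + gamma * V(f(s_t, a) + zeta_t)
    return -q.mean()                           # minimizing this ascends d/dtheta Q(s_t, pi(s_t, theta, zeta))
```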
SVG(∞) Algorithm
θ
s1 s2 . . . sT
a1 a2 . . . aT
z1 z2 . . . zT
RT
Just learn dynamics model f
Given whole trajectory, infer all noise variables
Freeze all policy and dynamics noise, differentiate through entire deterministic
computation graph
SVG Results
Applied to 2D robotics tasks
Overall: different gradient estimators behave similarly
N. Heess, G. Wayne, D. Silver, et al. “Learning continuous control policies by stochastic value gradients”. In: NIPS. 2015
Deterministic Policy Gradient
For Gaussian actions, the variance of the score function policy gradient estimator goes to
infinity as the action variance goes to zero
But the SVG(0) gradient is fine when σ → 0
∇θ Σt Q(st, π(st, θ, ζt))
Problem: there’s no exploration.
Solution: add noise to the policy, but estimate Q with TD(0), so it’s valid
off-policy
Policy gradient is a little biased (even with Q = Qπ), but only because state
distribution is off—it gets the right gradient at every state
D. Silver, G. Lever, N. Heess, et al. “Deterministic policy gradient algorithms”. In: ICML. 2014
Deep Deterministic Policy Gradient
Incorporate replay buffer and target network ideas from DQN for increased
stability
T. P. Lillicrap, J. J. Hunt, A. Pritzel, et al. “Continuous control with deep reinforcement learning”. In: ICLR (2015)
Deep Deterministic Policy Gradient
Incorporate replay buffer and target network ideas from DQN for increased
stability
Use lagged (Polyak-averaged) versions of Qφ and πθ for fitting Qφ (towards Qπ,γ) with TD(0)
Q̂t = rt + γ Qφ′(st+1, π(st+1; θ′))
T. P. Lillicrap, J. J. Hunt, A. Pritzel, et al. “Continuous control with deep reinforcement learning”. In: ICLR (2015)
Deep Deterministic Policy Gradient
Incorporate replay buffer and target network ideas from DQN for increased
stability
Use lagged (Polyak-averaged) versions of Qφ and πθ for fitting Qφ (towards Qπ,γ) with TD(0)
Q̂t = rt + γ Qφ′(st+1, π(st+1; θ′))
Pseudocode:
for iteration = 1, 2, . . . do
    Act for several timesteps, add data to replay buffer
    Sample minibatch
    Update πθ using g ∝ ∇θ Σ_{t=1}^{T} Q(st, π(st, zt; θ))
    Update Qφ using g ∝ ∇φ Σ_{t=1}^{T} (Qφ(st, at) − Q̂t)²
end for
T. P. Lillicrap, J. J. Hunt, A. Pritzel, et al. “Continuous control with deep reinforcement learning”. In: ICLR (2015)
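A compact PyTorch sketch of one minibatch update (assumed interfaces, not the paper's code): pi and Q are the online networks, pi_targ and Q_targ their lagged copies, and batch = (s, a, r, s_next, done) is sampled from the replay buffer.

```python
# Hypothetical DDPG-style minibatch update; network/optimizer/buffer interfaces are assumptions.
import torch

@torch.no_grad()
def polyak_update(net, targ, tau=0.005):
    # Lagged (Polyak-averaged) copy: targ <- (1 - tau) * targ + tau * net
    for p, p_t in zip(net.parameters(), targ.parameters()):
        p_t.mul_(1.0 - tau).add_(tau * p)

def ddpg_update(pi, Q, pi_targ, Q_targ, pi_opt, q_opt, batch, gamma=0.99):
    s, a, r, s_next, done = batch
    with torch.no_grad():                        # Q_hat_t = r_t + gamma * Q_targ(s_{t+1}, pi_targ(s_{t+1}))
        q_hat = r + gamma * (1.0 - done) * Q_targ(s_next, pi_targ(s_next))
    q_loss = ((Q(s, a) - q_hat) ** 2).mean()     # TD(0) regression toward the lagged target
    q_opt.zero_grad(); q_loss.backward(); q_opt.step()

    pi_loss = -Q(s, pi(s)).mean()                # deterministic policy gradient through the critic
    pi_opt.zero_grad(); pi_loss.backward(); pi_opt.step()

    polyak_update(pi, pi_targ); polyak_update(Q, Q_targ)
```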
DDPG Results
Applied to 2D and 3D robotics tasks and driving with pixel input
T. P. Lillicrap, J. J. Hunt, A. Pritzel, et al. “Continuous control with deep reinforcement learning”. In: ICLR (2015)
Policy Gradient Methods: Comparison
Y. Duan, X. Chen, R. Houthooft, et al. “Benchmarking Deep Reinforcement Learning for Continuous Control”. In: ICML (2016)