Pang Wei Koh and Percy Liang
"Understanding Black-box Predictions via Influence Functions", ICML 2017 Best Paper Award
References:
https://youtu.be/0w9fLX_T6tY
https://arxiv.org/abs/1703.04730
3. Introduction
A key question often asked of machine learning systems is
“Why did the system make this prediction?”
How can we explain where the model came from?
In this paper, we tackle this question by tracing a model’s predictions
through its learning algorithm and back to the training data, where the
model parameters ultimately derive from.
4. Introduction
Answering this question by perturbing the data and retraining the model
can be prohibitively expensive. To overcome this problem, we use
influence functions, a classic technique from robust statistics (Cook &
Weisberg, 1980) that tells us how the model parameters change as we
upweight a training point by an infinitesimal amount.
6. Approach
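We are given training points $z_1, \dots, z_n$, where $z_i = (x_i, y_i) \in \mathcal{X} \times \mathcal{Y}$. For a point z and parameters $\theta \in \Theta$, let $L(z, \theta)$ be the loss, and let $\frac{1}{n}\sum_{i=1}^{n} L(z_i, \theta)$ be the empirical risk, with minimizer
$$\hat\theta \stackrel{\text{def}}{=} \arg\min_{\theta\in\Theta} \frac{1}{n}\sum_{i=1}^{n} L(z_i, \theta).$$
Assume that the empirical risk is twice-differentiable and strictly convex in θ.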
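To make the setup concrete, the sketch below computes $\hat\theta$ for one assumed model: L2-regularized logistic regression with labels $y \in \{-1, +1\}$ and $L(z,\theta) = \log(1 + e^{-y x^\top\theta}) + \frac{\lambda}{2}\|\theta\|^2$. The $\lambda > 0$ term keeps the empirical risk twice differentiable and strictly convex, as assumed above. The synthetic data, $\lambda = 0.1$ and all function names are illustrative, not from the paper.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def empirical_risk(theta, X, y, lam):
    # (1/n) sum_i log(1 + exp(-y_i x_i^T theta)) + (lam/2) ||theta||^2
    return np.mean(np.logaddexp(0.0, -y * (X @ theta))) + 0.5 * lam * theta @ theta

def risk_grad(theta, X, y, lam):
    coef = -y * sigmoid(-y * (X @ theta))        # per-point d loss / d (x_i^T theta)
    return X.T @ coef / len(y) + lam * theta

def risk_hessian(theta, X, y, lam):
    s = sigmoid(y * (X @ theta))
    curvature = s * (1.0 - s)                     # per-point curvature, in (0, 1/4]
    return (X.T * curvature) @ X / len(y) + lam * np.eye(X.shape[1])

def fit_erm(X, y, lam, iters=100, tol=1e-12):
    """theta_hat = argmin_theta (1/n) sum_i L(z_i, theta), via Newton steps with step halving."""
    theta = np.zeros(X.shape[1])
    for _ in range(iters):
        step = np.linalg.solve(risk_hessian(theta, X, y, lam), risk_grad(theta, X, y, lam))
        t, current = 1.0, empirical_risk(theta, X, y, lam)
        # Halve the step until it does not increase the risk (a safeguard only).
        while empirical_risk(theta - t * step, X, y, lam) > current and t > 1e-10:
            t *= 0.5
        theta = theta - t * step
        if np.linalg.norm(t * step) < tol:
            break
    return theta

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = np.where(X @ np.array([1.0, -2.0, 0.5]) + 0.3 * rng.normal(size=200) > 0, 1.0, -1.0)
theta_hat = fit_erm(X, y, lam=0.1)
```

Because the Hessian is positive definite everywhere, Newton's method converges in a handful of iterations here.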
7. Approach
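Model parameters from training without z: $\hat\theta_{-z} \stackrel{\text{def}}{=} \arg\min_{\theta\in\Theta} \sum_{z_i \neq z} L(z_i,\theta)$
Model parameters from upweighting z by ε: $\hat\theta_{\epsilon,z} \stackrel{\text{def}}{=} \arg\min_{\theta\in\Theta} \frac{1}{n}\sum_{i=1}^{n} L(z_i,\theta) + \epsilon L(z,\theta)$
Model parameters from perturbing z into $z_\delta = (x+\delta, y)$: $\hat\theta_{\epsilon,z_\delta,-z} \stackrel{\text{def}}{=} \arg\min_{\theta\in\Theta} \frac{1}{n}\sum_{i=1}^{n} L(z_i,\theta) + \epsilon L(z_\delta,\theta) - \epsilon L(z,\theta)$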
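All three definitions are the same weighted empirical risk with different weights on z. A minimal sketch, assuming for illustration a ridge-regression loss $L(z,\theta) = \frac{1}{2}(x^\top\theta - y)^2 + \frac{\lambda}{2}\|\theta\|^2$, for which any weighted objective $\sum_i w_i L(z_i,\theta)$ has a closed-form minimizer. The data, the index k of z, and the helper names are hypothetical.

```python
import numpy as np

def weighted_ridge(X, y, w, lam):
    """argmin_theta sum_i w_i L(z_i, theta) for L(z, theta) = 0.5 (x^T theta - y)^2 + 0.5 lam ||theta||^2."""
    A = (X.T * w) @ X + w.sum() * lam * np.eye(X.shape[1])
    return np.linalg.solve(A, X.T @ (w * y))

rng = np.random.default_rng(0)
n, p, lam = 50, 3, 0.1
X = rng.normal(size=(n, p))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=n)
k = 0                                    # index of the training point z = (X[k], y[k])
w_erm = np.full(n, 1.0 / n)              # empirical-risk weights: 1/n on every point

theta_hat = weighted_ridge(X, y, w_erm, lam)

def theta_minus_z():
    """theta_{-z}: retrain on every point except z."""
    keep = np.arange(n) != k
    return weighted_ridge(X[keep], y[keep], w_erm[keep], lam)

def theta_up(eps):
    """theta_{eps,z}: add eps extra weight to z."""
    w = w_erm.copy()
    w[k] += eps
    return weighted_ridge(X, y, w, lam)

def theta_pert(eps, delta):
    """theta_{eps,z_delta,-z}: move eps weight from z onto z_delta = (x + delta, y)."""
    X_aug = np.vstack([X, X[k] + delta])
    y_aug = np.append(y, y[k])
    w = np.append(w_erm, eps)
    w[k] -= eps
    return weighted_ridge(X_aug, y_aug, w, lam)
```

Giving z weight zero is the same as deleting it, so `theta_up(-1/n)` reproduces `theta_minus_z()`; likewise `theta_pert(1/n, delta)` equals retraining with x replaced by x + δ.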
8. Approach
Let us begin by studying the change in model parameters due to
removing a point z from the training set.
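Formally, this change is $\hat\theta_{-z} - \hat\theta$; the corresponding changes for upweighting and for perturbing z are $\hat\theta_{\epsilon,z} - \hat\theta$ and $\hat\theta_{\epsilon,z_\delta,-z} - \hat\theta$.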
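11. Influence of upweighting on the parameters
The influence of upweighting z on the parameters is
$$\mathcal{I}_{\text{up,params}}(z) \stackrel{\text{def}}{=} \left.\frac{d\hat\theta_{\epsilon,z}}{d\epsilon}\right|_{\epsilon=0} = -H_{\hat\theta}^{-1}\nabla_\theta L(z,\hat\theta),$$
where $H_{\hat\theta} \stackrel{\text{def}}{=} \frac{1}{n}\sum_{i=1}^{n}\nabla_\theta^2 L(z_i,\hat\theta)$ is the Hessian and is positive definite (PD) by assumption. In essence, we form a quadratic approximation to the empirical risk around $\hat\theta$ and take a single Newton step; see Appendix A of the paper for a derivation. Since removing a point z is the same as upweighting it by $\epsilon = -\frac{1}{n}$, we can linearly approximate the parameter change due to removing z by computing
$$\hat\theta_{-z} - \hat\theta \approx -\frac{1}{n}\,\mathcal{I}_{\text{up,params}}(z),$$
without retraining the model.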
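A quick numerical check of the removal approximation, assuming an illustrative ridge-regression loss (per-point regularizer included in L) so that exact leave-one-out retraining is a single linear solve. It compares $-\frac{1}{n}\mathcal{I}_{\text{up,params}}(z)$ against $\hat\theta_{-z} - \hat\theta$ obtained by retraining; the data and names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
n, p, lam = 200, 3, 0.1
X = rng.normal(size=(n, p))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.5 * rng.normal(size=n)

def fit(X, y, lam):
    """ERM for the per-point loss L(z, theta) = 0.5 (x^T theta - y)^2 + 0.5 lam ||theta||^2."""
    m = len(y)
    return np.linalg.solve(X.T @ X / m + lam * np.eye(X.shape[1]), X.T @ y / m)

def grad_loss(x, y_i, theta):
    """grad_theta L(z, theta) for a single point z = (x, y_i)."""
    return (x @ theta - y_i) * x + lam * theta

theta_hat = fit(X, y, lam)
H = X.T @ X / n + lam * np.eye(p)           # H = (1/n) sum_i grad^2 L(z_i, theta_hat)

def influence_up_params(k):
    """I_up,params(z_k) = -H^{-1} grad_theta L(z_k, theta_hat)."""
    return -np.linalg.solve(H, grad_loss(X[k], y[k], theta_hat))

k = 0
approx = -influence_up_params(k) / n        # theta_{-z} - theta_hat ~= -(1/n) I_up,params(z)
keep = np.arange(n) != k
exact = fit(X[keep], y[keep], lam) - theta_hat   # brute force: retrain without z
```

For this quadratic loss the two differ only because retraining also drops z's own curvature from the Hessian, so the relative error is roughly of the order of z's leverage and shrinks as n grows relative to p.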
13. Perturbing a training input
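For a training point z = (x, y), define $z_\delta \stackrel{\text{def}}{=} (x + \delta, y)$. Consider the perturbation $z \to z_\delta$, and let $\hat\theta_{z_\delta,-z}$ be the empirical risk minimizer on the training points with $z_\delta$ in place of z. To approximate its effects, define the parameters resulting from moving ε mass from z onto $z_\delta$:
$$\hat\theta_{\epsilon,z_\delta,-z} \stackrel{\text{def}}{=} \arg\min_{\theta\in\Theta} \frac{1}{n}\sum_{i=1}^{n} L(z_i,\theta) + \epsilon L(z_\delta,\theta) - \epsilon L(z,\theta).$$
Setting $\epsilon = \frac{1}{n}$ recovers $\hat\theta_{z_\delta,-z}$, and
$$\mathcal{I}_{\text{pert,params}}(z) \stackrel{\text{def}}{=} \left.\frac{d\hat\theta_{\epsilon,z_\delta,-z}}{d\epsilon}\right|_{\epsilon=0} = \mathcal{I}_{\text{up,params}}(z_\delta) - \mathcal{I}_{\text{up,params}}(z) = -H_{\hat\theta}^{-1}\left(\nabla_\theta L(z_\delta,\hat\theta) - \nabla_\theta L(z,\hat\theta)\right).$$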
14. Perturbing a training input
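If x is continuous and δ is small, we can linearize in δ, just as $F(x+h) - F(x) \approx F'(x)\,h$ for small h (since $F'(x) = \lim_{h\to 0}\frac{F(x+h) - F(x)}{h}$):
$$\nabla_\theta L(z_\delta,\hat\theta) - \nabla_\theta L(z,\hat\theta) \approx \left[\nabla_x\nabla_\theta L(z,\hat\theta)\right]\delta,$$
which gives
$$\hat\theta_{z_\delta,-z} - \hat\theta \approx -\frac{1}{n}\,H_{\hat\theta}^{-1}\left[\nabla_x\nabla_\theta L(z,\hat\theta)\right]\delta.$$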
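The same linearization can be checked numerically. This sketch assumes an illustrative ridge loss, for which the mixed derivative has the closed form $\nabla_x\nabla_\theta L(z,\theta) = (x^\top\theta - y)I + x\theta^\top$; it compares $-\frac{1}{n}H_{\hat\theta}^{-1}[\nabla_x\nabla_\theta L(z,\hat\theta)]\delta$ with the exact change from retraining on $(x+\delta, y)$. The data, δ and the helper names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(2)
n, p, lam = 200, 3, 0.1
X = rng.normal(size=(n, p))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.5 * rng.normal(size=n)

def fit(X, y, lam):
    """ERM for L(z, theta) = 0.5 (x^T theta - y)^2 + 0.5 lam ||theta||^2."""
    m = len(y)
    return np.linalg.solve(X.T @ X / m + lam * np.eye(X.shape[1]), X.T @ y / m)

def mixed_grad(x, y_i, theta):
    """grad_x grad_theta L(z, theta): the p x p matrix M with
    grad_theta L((x + delta, y_i), theta) ~= grad_theta L((x, y_i), theta) + M @ delta."""
    return (x @ theta - y_i) * np.eye(len(x)) + np.outer(x, theta)

theta_hat = fit(X, y, lam)
H = X.T @ X / n + lam * np.eye(p)

k = 0
delta = 1e-3 * rng.normal(size=p)                      # small perturbation of x_k
approx = -np.linalg.solve(H, mixed_grad(X[k], y[k], theta_hat) @ delta) / n
X_pert = X.copy()
X_pert[k] += delta
exact = fit(X_pert, y, lam) - theta_hat                # brute force: retrain on (x_k + delta, y_k)
```

Here the approximation is the exact first-order term in δ, so the mismatch shrinks with ‖δ‖ rather than with n.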
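16. Efficient calculation
We discuss two techniques for approximating
$$s_{\text{test}} \stackrel{\text{def}}{=} H_{\hat\theta}^{-1}\nabla_\theta L(z_{\text{test}},\hat\theta),$$
both relying on the fact that the Hessian-vector product (HVP) of a single term in $H_{\hat\theta}$, $[\nabla_\theta^2 L(z_i,\hat\theta)]\,v$, can be computed for arbitrary v in the same time that $\nabla_\theta L(z_i,\hat\theta)$ would take, which is typically O(p) (Pearlmutter, 1994). Once $s_{\text{test}}$ is known, the influence of upweighting any training point z on the test loss is $\mathcal{I}_{\text{up,loss}}(z, z_{\text{test}}) = -s_{\text{test}}^\top\nabla_\theta L(z,\hat\theta)$.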
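Pearlmutter's trick obtains HVPs from automatic differentiation. As a dependency-free stand-in, the sketch below assumes L2-regularized logistic regression, whose Hessian has the closed form $\frac{1}{n}\sum_i \sigma(m_i)(1-\sigma(m_i))\,x_i x_i^\top + \lambda I$ with margins $m_i = y_i x_i^\top\theta$; applying it to v as two matrix-vector products costs O(np) and never forms the p × p matrix. This is an analytic HVP, not the autodiff trick itself; the data and names are hypothetical.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def hvp(theta, X, y, lam, v):
    """H v for the L2-regularized logistic risk in O(n p) time, never forming H.
    H = (1/n) sum_i s_i (1 - s_i) x_i x_i^T + lam I, with s_i = sigmoid(y_i x_i^T theta)."""
    s = sigmoid(y * (X @ theta))
    curvature = s * (1.0 - s)
    return X.T @ (curvature * (X @ v)) / len(y) + lam * v

rng = np.random.default_rng(3)
n, p, lam = 500, 4, 0.1
X = rng.normal(size=(n, p))
y = np.where(rng.random(n) < 0.5, -1.0, 1.0)
theta = rng.normal(size=p)          # any parameter vector works for checking H v
v = rng.normal(size=p)

# Explicit p x p Hessian, built only to check the HVP.
s = sigmoid(y * (X @ theta))
H_explicit = (X.T * (s * (1.0 - s))) @ X / n + lam * np.eye(p)
```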
17. Efficient calculation - Conjugate gradients (CG)
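Since $H_{\hat\theta} \succ 0$ by assumption, $H_{\hat\theta}^{-1}v \equiv \arg\min_t \left\{\frac{1}{2}t^\top H_{\hat\theta}\,t - v^\top t\right\}$. We can solve this with CG approaches that only require the evaluation of $H_{\hat\theta}\,t$, which takes O(np) time, without explicitly forming $H_{\hat\theta}$.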
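A minimal conjugate-gradient sketch that touches $H_{\hat\theta}$ only through HVPs, assuming the same kind of illustrative L2-regularized logistic loss with its closed-form HVP. It solves $H s_{\text{test}} = \nabla_\theta L(z_{\text{test}},\theta)$ and then scores every training point by $\mathcal{I}_{\text{up,loss}}(z_i, z_{\text{test}}) = -s_{\text{test}}^\top\nabla_\theta L(z_i,\theta)$. A random θ stands in for a trained $\hat\theta$ (the linear algebra is the same), and the data, test point and helper names are hypothetical.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def hvp(theta, X, y, lam, v):
    """H v for the L2-regularized logistic risk in O(np), without forming H."""
    s = sigmoid(y * (X @ theta))
    return X.T @ (s * (1.0 - s) * (X @ v)) / len(y) + lam * v

def grad_loss(theta, x, y_i, lam):
    """grad_theta L(z, theta) for one point z = (x, y_i), y_i in {-1, +1}."""
    return -y_i * sigmoid(-y_i * (x @ theta)) * x + lam * theta

def conjugate_gradient(matvec, b, tol=1e-10, max_iter=None):
    """Solve A x = b for symmetric positive definite A, touching A only through matvec."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    d = r.copy()
    rs = r @ r
    for _ in range(max_iter or 10 * len(b)):
        if np.sqrt(rs) < tol:
            break
        Ad = matvec(d)
        alpha = rs / (d @ Ad)
        x += alpha * d
        r -= alpha * Ad
        rs_new = r @ r
        d = r + (rs_new / rs) * d
        rs = rs_new
    return x

rng = np.random.default_rng(3)
n, p, lam = 500, 4, 0.1
X = rng.normal(size=(n, p))
y = np.where(rng.random(n) < 0.5, -1.0, 1.0)
theta = rng.normal(size=p)                 # stands in for a trained theta_hat
x_test, y_test = rng.normal(size=p), 1.0   # hypothetical test point

g_test = grad_loss(theta, x_test, y_test, lam)
s_test = conjugate_gradient(lambda v: hvp(theta, X, y, lam, v), g_test)

# I_up,loss(z_i, z_test) = -s_test^T grad_theta L(z_i, theta), for all training points at once
train_grads = -(y * sigmoid(-y * (X @ theta)))[:, None] * X + lam * theta
influences = -train_grads @ s_test
```

Each CG iteration costs one HVP, i.e. O(np); in exact arithmetic CG finishes in at most p iterations, although on large models it is usually stopped earlier.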
18. Efficient calculation - Stochastic estimation
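Dropping the $\hat\theta$ subscript for clarity, let $H_j^{-1} \stackrel{\text{def}}{=} \sum_{i=0}^{j}(I - H)^i$, the first j + 1 terms in the Taylor expansion of $H^{-1}$. Rewrite this recursively as $H_j^{-1} = I + (I - H)H_{j-1}^{-1}$. From the validity of the Taylor expansion, $H_j^{-1} \to H^{-1}$ as $j \to \infty$. This requires every eigenvalue of H to lie in (0, 2), which can be ensured by scaling the loss down so that $\nabla_\theta^2 L(z_i,\hat\theta) \preceq I$ for every i; scaling does not change the minimizer. The key is that at each iteration, we can substitute the full H with a draw from any unbiased (and faster-to-compute) estimator of H to form $\tilde H_j$. Since $\mathbb{E}[\tilde H_j^{-1}] = H_j^{-1}$, we still have $\mathbb{E}[\tilde H_j^{-1}] \to H^{-1}$.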
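The recursion can be checked on a small matrix whose inverse is known exactly. This sketch assumes a random 5 × 5 symmetric positive definite H, rescaled so that its eigenvalues lie in (0, 1) (mirroring the scaling condition above); it is a deterministic check of $H_j^{-1} \to H^{-1}$, not the stochastic estimator itself.

```python
import numpy as np

rng = np.random.default_rng(4)
p = 5
A = rng.normal(size=(p, p))
H = A @ A.T / p + 0.5 * np.eye(p)              # symmetric positive definite
H = H / (1.1 * np.linalg.eigvalsh(H).max())    # scale so every eigenvalue lies in (0, 1)

I = np.eye(p)
H_inv_exact = np.linalg.inv(H)
H_inv_j = I.copy()                             # H_0^{-1} = (I - H)^0 = I
errors = []
for j in range(1, 501):
    H_inv_j = I + (I - H) @ H_inv_j            # H_j^{-1} = I + (I - H) H_{j-1}^{-1}
    errors.append(np.linalg.norm(H_inv_j - H_inv_exact))
```

The error after j steps is $-(I - H)^{j+1}H^{-1}$, so it shrinks geometrically at a rate set by the smallest eigenvalue of H; an eigenvalue above 2 would instead make the partial sums diverge.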
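19. Efficient calculation - Stochastic estimation
Concretely, we uniformly sample t points $z_{s_1}, \dots, z_{s_t}$ from the training data, set $\tilde H_0^{-1}v = v$, and recursively compute
$$\tilde H_j^{-1}v = v + \left(I - \nabla_\theta^2 L(z_{s_j},\hat\theta)\right)\tilde H_{j-1}^{-1}v,$$
taking $\tilde H_t^{-1}v$ as the final estimate. We pick t large enough that the estimate stabilizes, and to reduce variance we repeat this procedure r times and average the results.
Empirically, we found this significantly faster than CG.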
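A sketch of the stochastic recursion, assuming an illustrative ridge loss whose per-point Hessian is $x x^\top + \lambda I$, so the exact $H^{-1}v$ is available for comparison. The loss is scaled by a constant c so that every per-point Hessian is ≼ I, and the result is multiplied back by c; the depth t = 1000 and r = 500 repeats are illustrative, not the paper's settings. The r repeats run side by side as rows of one array, which is only a vectorization convenience.

```python
import numpy as np

rng = np.random.default_rng(5)
n, p, lam = 1000, 3, 0.1
X = rng.normal(size=(n, p))
v = rng.normal(size=p)

# Per-point Hessian of the ridge loss: x x^T + lam I. Scale the loss by c so that
# c * (||x||^2 + lam) <= 1, i.e. every scaled per-point Hessian is <= I.
c = 1.0 / ((X ** 2).sum(axis=1).max() + lam)

def lissa(v, depth=1000, repeats=500):
    """Stochastic estimate of H^{-1} v, with H = (1/n) sum_i (x_i x_i^T + lam I)."""
    h = np.tile(v, (repeats, 1))                  # H~_0^{-1} v = v, one row per repeat
    for _ in range(depth):
        xs = X[rng.integers(n, size=repeats)]     # a uniformly sampled z_{s_j} for each repeat
        hess_h = xs * (xs * h).sum(axis=1, keepdims=True) + lam * h  # grad^2 L(z_{s_j}) h, row-wise
        h = v + h - c * hess_h                    # v + (I - c grad^2 L(z_{s_j})) h
    return c * h.mean(axis=0)                     # rows estimate (cH)^{-1} v; undo the scaling

estimate = lissa(v)
exact = np.linalg.solve(X.T @ X / n + lam * np.eye(p), v)
```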
20. Non-convexity and non-convergence
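In non-convex models, or when training stops early, the parameters $\tilde\theta$ returned by SGD may not be the global minimum, so $H_{\tilde\theta}$ can have negative eigenvalues. Our approach is to form a convex quadratic approximation of the loss around $\tilde\theta$, i.e.,
$$\tilde L(z,\theta) = L(z,\tilde\theta) + \nabla L(z,\tilde\theta)^\top(\theta - \tilde\theta) + \frac{1}{2}(\theta - \tilde\theta)^\top\left(H_{\tilde\theta} + \lambda I\right)(\theta - \tilde\theta).$$
Here, λ is a damping term that we add if $H_{\tilde\theta}$ has negative eigenvalues; this corresponds to adding L2 regularization on θ. We then calculate $\mathcal{I}_{\text{up,loss}}$ using $\tilde L$. If $\tilde\theta$ is close to a local minimum, this is correlated with the result of taking a Newton step from $\tilde\theta$ after removing ε weight from z.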
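Why damping works: let $X \in \mathbb{R}^{m\times m}$ be a symmetric matrix with eigendecomposition $X = U\Sigma U^\top$ (U orthogonal, Σ diagonal). Since $I = UU^\top = UIU^\top$,
$$X + \lambda I = U(\Sigma + \lambda I)U^\top,$$
so adding λI keeps the eigenvectors and shifts every eigenvalue up by λ. Choosing λ larger than the magnitude of the most negative eigenvalue of $H_{\tilde\theta}$ therefore makes $H_{\tilde\theta} + \lambda I$ positive definite.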
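A small numerical illustration of the eigenvalue shift, assuming a random symmetric matrix in the role of $H_{\tilde\theta}$ and a hypothetical per-point gradient g. It picks the smallest damping that makes the matrix PD (plus an arbitrary margin of 0.01) and then computes the parameter influence $-(H_{\tilde\theta} + \lambda I)^{-1}g$ from the damped quadratic model.

```python
import numpy as np

rng = np.random.default_rng(6)
m = 4
B = rng.normal(size=(m, m))
H_tilde = (B + B.T) / 2                      # symmetric; a random one is usually indefinite
eigvals, U = np.linalg.eigh(H_tilde)         # H_tilde = U diag(eigvals) U^T, U orthogonal

# Smallest damping that makes the matrix PD, plus an arbitrary margin of 0.01.
lam = max(0.0, -eigvals.min()) + 0.01
damped = H_tilde + lam * np.eye(m)           # same eigenvectors, eigenvalues + lam

# Influence on parameters under the damped quadratic model,
# for a hypothetical gradient g = grad L(z, theta_tilde).
g = rng.normal(size=m)
influence = -np.linalg.solve(damped, g)
```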
24. Applications - Understanding model behavior
Influence functions reveal insights about
how models rely on and extrapolate from the training data.
Inception-v3 vs. RBF SVM (trained with a smooth hinge loss, SmoothHinge)
• The Inception network (a DNN) picked up on the distinctive characteristics of the fish.
• The RBF SVM pattern-matched training images superficially.
29. Application - Debugging domain mismatch
If a model makes a mistake, can we find out why?
Training set, original vs. modified:
Group                          Original   Modified   Change
Healthy + re-admitted adults   ~20K       ~20K       same
Healthy children               21         1          −20
Re-admitted children           3          3          same
Domain mismatch, where the training distribution does not match the test distribution, can cause models with high training accuracy to do poorly on test data.
(…)
We predicted whether a patient would be re-admitted to hospital. We used logistic regression to predict readmission with a balanced training dataset of 20K diabetic patients from 100+ US hospitals, each represented by 127 features.
(…)
This caused the model to wrongly classify many children in the test set.
30. Application - Debugging domain mismatch
True test label: healthy child
Model prediction: re-admitted child
[Figure: top 20 influential training examples for this test prediction; y-axis: influence]
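Rankings like the one in this figure can be sketched as follows, assuming (for illustration) L2-regularized logistic regression on synthetic data rather than the hospital dataset, with the regularizer folded into L. Each training point is scored by $\mathcal{I}_{\text{up,loss}}(z_i, z_{\text{test}}) = -s_{\text{test}}^\top\nabla_\theta L(z_i,\hat\theta)$, where $s_{\text{test}}$ is computed by a direct solve since p is tiny; the 20 largest in magnitude are kept. The test point, data and helper names are hypothetical.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def loss(theta, x, y_i, lam):
    """L(z, theta) for one point: logistic loss plus the L2 term folded into L."""
    return np.logaddexp(0.0, -y_i * (x @ theta)) + 0.5 * lam * theta @ theta

def grads(theta, X, y, lam):
    """Row i is grad_theta L(z_i, theta)."""
    return -(y * sigmoid(-y * (X @ theta)))[:, None] * X + lam * theta

def hessian(theta, X, y, lam):
    """H_theta = (1/n) sum_i grad^2_theta L(z_i, theta)."""
    s = sigmoid(y * (X @ theta))
    return (X.T * (s * (1.0 - s))) @ X / len(y) + lam * np.eye(X.shape[1])

def fit(X, y, lam, iters=100):
    """theta_hat via Newton's method with step halving; the risk is strictly convex."""
    risk = lambda th: np.mean(np.logaddexp(0.0, -y * (X @ th))) + 0.5 * lam * th @ th
    theta = np.zeros(X.shape[1])
    for _ in range(iters):
        step = np.linalg.solve(hessian(theta, X, y, lam), grads(theta, X, y, lam).mean(axis=0))
        t = 1.0
        while risk(theta - t * step) > risk(theta) and t > 1e-10:
            t *= 0.5
        theta = theta - t * step
        if np.linalg.norm(t * step) < 1e-12:
            break
    return theta

rng = np.random.default_rng(7)
n, p, lam = 200, 3, 0.1
X = rng.normal(size=(n, p))
y = np.where(X @ np.array([1.0, -2.0, 0.5]) + 0.5 * rng.normal(size=n) > 0, 1.0, -1.0)
x_test, y_test = rng.normal(size=p), 1.0      # hypothetical test point z_test

theta_hat = fit(X, y, lam)
g_test = grads(theta_hat, x_test[None, :], np.array([y_test]), lam)[0]
s_test = np.linalg.solve(hessian(theta_hat, X, y, lam), g_test)
influence = -grads(theta_hat, X, y, lam) @ s_test    # I_up,loss(z_i, z_test) for every i
predicted_change = -influence / n                     # ~ L(z_test, theta_{-z_i}) - L(z_test, theta_hat)
top20 = np.argsort(-np.abs(influence))[:20]           # the 20 most influential training points
```

Since removing $z_i$ corresponds to ε = −1/n, `predicted_change` approximates the change in test loss from leave-one-out retraining; on a problem this small, comparing it against actual retraining is a cheap sanity check.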
32. Application - Fixing mislabeled examples
Training labels are noisy, and we have a small budget to manually inspect them
Can we prioritize which labels to try to fix?
Even if a human expert could recognize wrongly labeled examples, it is impossible in many applications to manually review all of the training data. We show that influence functions can help human experts prioritize their attention, allowing them to inspect only the examples that actually matter.
[Figure: a row of Ham/Spam email labels, before and after some labels are flipped]
We flipped the labels of a random 10% of the training data
33. Application - Fixing mislabeled examples
Plots of how test accuracy (left) and the fraction of flipped data
detected (right) change with the fraction of train data checked
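The ranking behind plots like these can be sketched as below. Following the paper (the slide does not state the criterion), training points are checked in order of self-influence $\mathcal{I}_{\text{up,loss}}(z_i, z_i) = -\nabla_\theta L(z_i,\hat\theta)^\top H_{\hat\theta}^{-1}\nabla_\theta L(z_i,\hat\theta)$, which approximates how much $z_i$'s own loss would grow if it were removed. The model (L2-regularized logistic regression), the synthetic data and the helper names are assumptions; only the 10% flip rate comes from the slide.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def grads(theta, X, y, lam):
    """Row i is grad_theta L(z_i, theta) for the L2-regularized logistic loss."""
    return -(y * sigmoid(-y * (X @ theta)))[:, None] * X + lam * theta

def hessian(theta, X, y, lam):
    s = sigmoid(y * (X @ theta))
    return (X.T * (s * (1.0 - s))) @ X / len(y) + lam * np.eye(X.shape[1])

def fit(X, y, lam, iters=100):
    """theta_hat via Newton's method with step halving."""
    risk = lambda th: np.mean(np.logaddexp(0.0, -y * (X @ th))) + 0.5 * lam * th @ th
    theta = np.zeros(X.shape[1])
    for _ in range(iters):
        step = np.linalg.solve(hessian(theta, X, y, lam), grads(theta, X, y, lam).mean(axis=0))
        t = 1.0
        while risk(theta - t * step) > risk(theta) and t > 1e-10:
            t *= 0.5
        theta = theta - t * step
        if np.linalg.norm(t * step) < 1e-12:
            break
    return theta

rng = np.random.default_rng(8)
n, p, lam = 200, 3, 0.01
X = rng.normal(size=(n, p))
y = np.where(X @ np.array([1.0, -2.0, 0.5]) > 0, 1.0, -1.0)
flipped = rng.choice(n, size=n // 10, replace=False)   # flip the labels of a random 10%
y[flipped] *= -1

theta_hat = fit(X, y, lam)
G = grads(theta_hat, X, y, lam)                          # row i: grad_theta L(z_i, theta_hat)
H_inv_G = np.linalg.solve(hessian(theta_hat, X, y, lam), G.T).T
self_influence = -(G * H_inv_G).sum(axis=1)              # I_up,loss(z_i, z_i), always <= 0
order = np.argsort(self_influence)                       # most negative first: check these first

def fraction_detected(frac_checked):
    """Fraction of flipped labels found after checking the top frac_checked of the ranking."""
    checked = order[: int(frac_checked * n)]
    return np.isin(flipped, checked).mean()
```

`fraction_detected(f)` gives one point on a "fraction of flipped data detected vs. fraction of train data checked" curve.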
35. References
Pang Wei Koh and Percy Liang. "Understanding Black-box Predictions via Influence Functions." ICML 2017 (Best Paper Award).
Paper: https://arxiv.org/abs/1703.04730
Microsoft Research: Understanding Black-box Predictions via Influence Functions (Pang Wei Koh)
YouTube: https://youtu.be/0w9fLX_T6tY