Overconfidence and Subnetwork Inference for BNNs, based on:
"Being Bayesian, Even Just a Bit, Fixes Overconfidence in ReLU Networks" and "Bayesian Deep Learning via Subnetwork Inference"
1. Confidence control and subnetwork inference for Bayesian Neural Networks
Tomasz Kuśmierczyk
2022-04-29
Based on:
Kristiadi et al.: Being Bayesian, Even Just a Bit, Fixes Overconfidence in ReLU Networks, ICML 2020
Daxberger et al.: Bayesian Deep Learning via Subnetwork Inference, ICML 2021
5. Reminder: Bayesian inference and learning
● we can sample from the posterior over the network weights using MCMC (HMC is the gold standard)
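A minimal sketch of sampling-based Bayesian prediction, assuming a toy logistic-regression "network" with two weights. It uses random-walk Metropolis rather than HMC to stay short; the data, prior variance, step size and all function names are made up for illustration.

```python
import numpy as np

def log_posterior(theta, X, y, prior_var=10.0):
    """Unnormalized log posterior: Bernoulli likelihood plus zero-mean Gaussian prior."""
    logits = X @ theta
    log_lik = np.sum(y * logits - np.logaddexp(0.0, logits))
    log_prior = -0.5 * np.sum(theta ** 2) / prior_var
    return log_lik + log_prior

def random_walk_metropolis(X, y, n_steps=5000, step=0.3, seed=0):
    """Random-walk Metropolis over the weights (a simple stand-in for HMC)."""
    rng = np.random.default_rng(seed)
    theta = np.zeros(X.shape[1])
    logp = log_posterior(theta, X, y)
    samples = []
    for _ in range(n_steps):
        proposal = theta + step * rng.standard_normal(theta.shape)
        logp_prop = log_posterior(proposal, X, y)
        # Accept with probability min(1, p(proposal) / p(theta)).
        if np.log(rng.uniform()) < logp_prop - logp:
            theta, logp = proposal, logp_prop
        samples.append(theta.copy())
    return np.array(samples)

# Made-up data: the label depends mostly on the first input dimension.
rng = np.random.default_rng(1)
X = rng.standard_normal((100, 2))
y = (X[:, 0] + 0.5 * rng.standard_normal(100) > 0).astype(float)

samples = random_walk_metropolis(X, y)[1000:]   # drop burn-in
x_new = np.array([1.0, 0.0])
# Posterior predictive: average the sigmoid output over posterior samples.
p_new = np.mean(1.0 / (1.0 + np.exp(-(samples @ x_new))))
```

The predictive probability is an average over weight samples, not the output at a single point estimate; this is the quantity the approximations on the following slides try to match cheaply.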
6. Approximate (distributional) inference
1. Which q do we choose?
e.g. factorized Gaussian, mixture of Gaussians
2. How do we fit its parameters?
e.g. variational inference (VI), Laplace approximation
7. The predictive distribution for the binary case
Gaussian approximation:
Linearization around mean:
Probit approximation:
Assuming:
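The three steps above can be sketched numerically. Assuming a Gaussian over the weights, N(mu, Sigma), and a logit linearized around the mean with gradient d, the logit is Gaussian with mean f(x; mu) and variance d^T Sigma d, and the probit approximation gives p(y=1|x) ≈ sigmoid(f(x; mu) / sqrt(1 + (pi/8) d^T Sigma d)). All numbers and names below are made up for illustration; the Monte Carlo estimate is only there as a sanity check.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def probit_predictive(mean_logit, var_logit):
    """Probit approximation of E[sigmoid(f)] for f ~ N(mean_logit, var_logit)."""
    return sigmoid(mean_logit / np.sqrt(1.0 + np.pi / 8.0 * var_logit))

def monte_carlo_predictive(mean_logit, var_logit, n_samples=200_000, seed=0):
    """Reference value: average the sigmoid over samples of the Gaussian logit."""
    rng = np.random.default_rng(seed)
    f = mean_logit + np.sqrt(var_logit) * rng.standard_normal(n_samples)
    return sigmoid(f).mean()

# Linearization around the mean: f(x; theta) ~ f(x; mu) + d^T (theta - mu).
d = np.array([1.0, -2.0, 0.5])        # gradient of the logit at mu (made-up values)
Sigma = np.diag([0.5, 0.2, 1.0])      # posterior covariance (made-up values)
mean_logit = 2.0                      # f(x; mu) (made-up value)
var_logit = d @ Sigma @ d             # variance of the linearized logit

p_map = sigmoid(mean_logit)                           # point-estimate prediction
p_probit = probit_predictive(mean_logit, var_logit)   # approximate Bayesian prediction
p_mc = monte_carlo_predictive(mean_logit, var_logit)  # sampling-based check
```

Because the denominator is at least 1, the approximate Bayesian prediction is never more confident than the point-estimate prediction, and it moves toward 0.5 as the logit variance grows.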
9. The predictive distribution for the binary case
Gaussian approximation:
Assuming:
Laplace fit: with prior:
full generalized Gauss-Newton (GGN) approximation of the Hessian
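For reference, a sketch of a Laplace fit with a GGN Hessian in commonly used notation. The symbols (theta_MAP, sigma_0 for the prior scale, J_n for the per-example Jacobian, Lambda_n for the output-space curvature) are assumptions chosen here, not necessarily those of the slide.

```latex
\begin{align}
p(\theta \mid \mathcal{D}) &\approx \mathcal{N}(\theta \mid \mu, \Sigma),
  \qquad \mu = \theta_{\mathrm{MAP}}, \\
p(\theta) &= \mathcal{N}(\theta \mid 0, \sigma_0^2 I),
  \qquad \Sigma^{-1} = H_{\mathrm{GGN}} + \sigma_0^{-2} I, \\
H_{\mathrm{GGN}} &= \sum_{n=1}^{N} J_n^\top \Lambda_n J_n,
  \qquad J_n = \left.\frac{\partial f(x_n;\theta)}{\partial \theta}\right|_{\theta=\mu},
  \qquad \Lambda_n = -\left.\frac{\partial^2 \log p(y_n \mid f)}{\partial f^2}\right|_{f = f(x_n;\mu)} .
\end{align}
```

For a binary classifier with a sigmoid likelihood, \Lambda_n reduces to the scalar \sigma(f)(1 - \sigma(f)) evaluated at f = f(x_n; \mu).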
10. One can control confidence by controlling approximation parameters
smallest singular value of
smallest eigenvalue
a vector depending only on approximation parameters
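For reference, a sketch of the asymptotic confidence bound from Kristiadi et al. (their all-layer result), written up to notation. The symbols are assumptions here: z(x) is the probit-scaled logit, u a vector depending only on the mean \mu, J the Jacobian of u with respect to \theta at \mu, s_min the smallest singular value and \lambda_min the smallest eigenvalue. The exact statement should be checked against the paper.

```latex
\begin{align}
z(x) &= \frac{f(x;\mu)}{\sqrt{1 + \tfrac{\pi}{8}\, v(x)}},
  \qquad v(x) = d(x)^\top \Sigma\, d(x), \\
\lim_{\delta \to \infty} \sigma\bigl(|z(\delta x)|\bigr)
  &\le \sigma\!\left( \frac{\lVert u \rVert}{s_{\min}(J)\,\sqrt{\tfrac{\pi}{8}\,\lambda_{\min}(\Sigma)}} \right).
\end{align}
```

The right-hand side shrinks as \lambda_min(\Sigma) grows, so a more uncertain Gaussian approximation pushes the far-away confidence toward 0.5.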
11. (Laplace) Last-layer approximation (LLLA)
Gaussian approximation for the last layer:
Binary linear classifier:
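A minimal sketch of a last-layer Laplace approximation, assuming a fixed, bias-free ReLU feature map (a made-up stand-in for all layers but the last) and a Gaussian prior with unit precision. The last layer is a binary linear classifier on the features, so its Hessian is available in closed form. All names and data are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical fixed feature map standing in for the trained lower layers.
W_FEAT = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])

def features(X):
    return np.maximum(X @ W_FEAT.T, 0.0)

def fit_last_layer_map(Phi, y, prior_prec=1.0, n_iter=2000):
    """Gradient descent on the negative log posterior of the last-layer weights."""
    w = np.zeros(Phi.shape[1])
    # Step size 1/L, where L upper-bounds the curvature of the objective.
    step = 1.0 / (0.25 * np.linalg.norm(Phi, 2) ** 2 + prior_prec)
    for _ in range(n_iter):
        grad = Phi.T @ (sigmoid(Phi @ w) - y) + prior_prec * w
        w -= step * grad
    return w

def laplace_covariance(Phi, w, prior_prec=1.0):
    """Inverse Hessian at w (for a linear last layer the GGN equals the Hessian)."""
    p = sigmoid(Phi @ w)
    H = Phi.T @ (Phi * (p * (1.0 - p))[:, None]) + prior_prec * np.eye(Phi.shape[1])
    return np.linalg.inv(H)

# Made-up two-class data: one blob around (2, 2), one around (-2, -2).
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(2.0, 1.0, (50, 2)), rng.normal(-2.0, 1.0, (50, 2))])
y = np.concatenate([np.ones(50), np.zeros(50)])
Phi = features(X)
w_map = fit_last_layer_map(Phi, y)
Sigma = laplace_covariance(Phi, w_map)

# Move away from the data along a fixed direction and compare confidences.
deltas = np.array([1.0, 10.0, 100.0])
x = np.array([1.0, 1.0])
m = np.array([features(d * x) @ w_map for d in deltas])                     # MAP logit
v = np.array([features(d * x) @ Sigma @ features(d * x) for d in deltas])  # logit variance
z = m / np.sqrt(1.0 + np.pi / 8.0 * v)                                      # probit-scaled logit
conf_map = sigmoid(np.abs(m))
conf_llla = sigmoid(np.abs(z))
```

In this sketch the MAP logit grows linearly with the scaling factor, while the probit-scaled logit stays below |m(1)| / sqrt((pi/8) v(1)), because the logit variance grows quadratically along the same ray.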
15. Subnetwork selection
Split weights into point-wise and distributional:
Select a subset S minimizing a distance to the posterior predictive:
taking the Wasserstein distance, assuming a factorized Gaussian posterior (independence between weights) for both arguments, a linearized model, and a Laplace fit with the GGN approximation:
d-th diagonal entry of the inverse of the (approximated) Hessian
1 iff selected; 0 otherwise
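A minimal sketch of the resulting selection rule. Under the assumptions listed above, the squared Wasserstein distance reduces to the sum of marginal variances of the weights that are not selected, sum_d var_d * (1 - m_d), so the distance is minimized by keeping the weights with the largest marginal variance. The Hessian values and function names are made up for illustration.

```python
import numpy as np

def select_subnetwork(marginal_var, k):
    """Return a 0/1 mask with 1 for the k weights of largest marginal variance (k >= 1)."""
    mask = np.zeros(marginal_var.shape, dtype=int)
    mask[np.argsort(marginal_var)[-k:]] = 1
    return mask

def squared_wasserstein_diag(marginal_var, mask):
    """Sum of the variances of the weights kept as point estimates."""
    return float(np.sum(marginal_var * (1 - mask)))

diag_hessian = np.array([4.0, 0.5, 10.0, 1.0, 2.0])  # made-up diagonal (approximated) Hessian
marginal_var = 1.0 / diag_hessian                    # d-th entry of its inverse
mask = select_subnetwork(marginal_var, k=2)
distance = squared_wasserstein_diag(marginal_var, mask)
```

Here the second and fourth weights have the largest variances (2.0 and 1.0), so they form the subnetwork and the remaining three weights stay point estimates.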