1. Scale-invariant Bayesian Neural Networks
with Connectivity Tangent Kernel
Sungyub Kim, Sihwan Park, Kyungsu Kim, Eunho Yang
ICLR 2023
Graduate School of AI, KAIST
2. Background
• Data-dependent PAC-Bayes bound
• PAC-Bayes bound: the generalization gap is bounded by the KL divergence between the prior and the posterior
• Data-dependent PAC-Bayes bound: use a data-dependent PAC-Bayes prior
• Partition the training set 𝒮 into 𝒮_P and 𝒮_Q
• Pre-train a prior P_𝒟 with 𝒮_P (therefore, P_𝒟 and 𝒮_Q are independent)
• Fine-tune the posterior Q with the entire dataset 𝒮
$$\mathrm{err}_{\mathcal{D}}(Q) \le \mathrm{err}_{\mathcal{S}}(Q) + \sqrt{\frac{\mathrm{KL}[Q \,\|\, P] + \log(2\sqrt{N}/\delta)}{2N}}$$
$$\mathrm{err}_{\mathcal{D}}(Q) \le \mathrm{err}_{\mathcal{S}_Q}(Q) + \sqrt{\frac{\mathrm{KL}[Q \,\|\, P_{\mathcal{D}}] + \log(2\sqrt{N_Q}/\delta)}{2N_Q}}$$

(Independent: the prior $P_{\mathcal{D}}$ is trained only on $\mathcal{S}_P$, so it is independent of the evaluation split $\mathcal{S}_Q$.)
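As a quick sanity check, here is a minimal Python sketch (my own toy numbers, not from the paper) evaluating the complexity term of the data-dependent bound above:

```python
import math

def pac_bayes_gap(kl: float, n: int, delta: float = 0.05) -> float:
    """Complexity term sqrt((KL[Q||P] + log(2*sqrt(N)/delta)) / (2N))."""
    return math.sqrt((kl + math.log(2.0 * math.sqrt(n) / delta)) / (2.0 * n))

# Data-dependent variant: evaluate on the held-out split S_Q with the prior
# P_D pre-trained on S_P. The KL value and |S_Q| below are hypothetical.
print(pac_bayes_gap(kl=500.0, n=25_000))  # bounds err_D(Q) - err_SQ(Q)
```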
3. Background
• Invariance of generalization bounds for function-preserving transformations
• Function-preserving scaling transformations (FPST)
• Scale transformation (diagonal matrix) of parameters that preserves the function
• Practical examples
• 1) Rescaling transformations of positively homogeneous (e.g., ReLU) NNs
• 2) Weight decay (WD) with Batch Normalization (BN) layers
$$f(x, T_\sigma(\theta)) = f(x, \theta), \quad \forall x \in \mathbb{R}^D,\ \forall \theta \in \mathbb{R}^P$$
$$(R_{\sigma,l,k}(\theta))_i = \begin{cases} \sigma \cdot \theta_i & \text{if } \theta_i \in \{\text{params connecting as input edges to the } k\text{-th activation at the } l\text{-th layer}\} \\ \theta_i / \sigma & \text{if } \theta_i \in \{\text{params connecting as output edges to the } k\text{-th activation at the } l\text{-th layer}\} \\ \theta_i & \text{otherwise} \end{cases}$$

$$(S_{\sigma,l,k}(\theta))_i = \begin{cases} \sigma_k \cdot \theta_i & \text{if } \theta_i \in \{\text{params connecting as input edges to the } k\text{-th activation at the } l\text{-th layer}\} \\ \theta_i & \text{otherwise} \end{cases}$$
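To make the rescaling transformation $R_{\sigma,l,k}$ concrete, here is a minimal NumPy sketch (a toy two-layer ReLU network of my own, not the paper's code) verifying that it preserves the function:

```python
import numpy as np

rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(16, 8)), rng.normal(size=(4, 16))

def f(x, W1, W2):
    return W2 @ np.maximum(W1 @ x, 0.0)  # two-layer ReLU network

# R_{sigma,1,k}: scale the input edges of hidden activation k by sigma
# and its output edges by 1/sigma; ReLU's positive homogeneity cancels them.
sigma, k = 10.0, 3
W1_t, W2_t = W1.copy(), W2.copy()
W1_t[k, :] *= sigma    # input edges of activation k
W2_t[:, k] /= sigma    # output edges of activation k

x = rng.normal(size=8)
assert np.allclose(f(x, W1, W2), f(x, W1_t, W2_t))  # function preserved
```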
4. Background
• Invariance of generalization bounds for function-preserving transformations
• Existing sharpness metrics are vulnerable to FPSTs.
• Dinh et al. (2017) showed the vulnerability of sharpness to rescaling transformations.
• Rescaling two successive layers can make an NN arbitrarily sharp while preserving its function.
• Li et al. (2018) showed the vulnerability of sharpness to weight decay.
• WD improves generalizability but sharpens NNs.
• Prior works focused only on the scale-invariance of sharpness metrics (not of generalization bounds).
• The generalization bounds of Tsuzuku et al. (2020) and Kwon et al. (2021) include scale-dependent terms.
• The generalization bound of Petzka et al. (2021) only holds for single-layer NNs.
5. Key Idea – Modification of the Jacobian
• Many sharpness metrics / generalization bounds are based on the Jacobian.
• Hessian, Fisher, and Gauss-Newton (GN) matrix: (Jacobian)ᵀ(Jacobian)
• Neural Tangent Kernel (NTK): (Jacobian)(Jacobian)ᵀ
• However, the Jacobian w.r.t. the parameters is not invariant to FPSTs.
• To mitigate this issue, we consider the Jacobian w.r.t. connectivity (see the numerical sketch after the equations below).
• Note that we relax the binary constraint on connectivity to obtain a differentiable formulation.
$$G = J_\theta^\top J_\theta, \qquad \Theta = J_\theta J_\theta^\top$$
$$J_\theta(x, T_\sigma(\theta)) = J_\theta(x, \theta)\, T_\sigma^{-1} \neq J_\theta(x, \theta)$$

$$J_c(x, T_\sigma(\theta) \odot c)\big|_{c = 1_P} = J_\theta(x, T_\sigma(\theta))\,\operatorname{diag}(T_\sigma(\theta)) = J_\theta(x, \theta)\, T_\sigma^{-1} T_\sigma \operatorname{diag}(\theta) = J_c(x, \theta \odot c)\big|_{c = 1_P}$$
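These identities can be checked numerically. Below is a self-contained sketch (toy network and finite-difference Jacobians of my own devising) showing that $J_\theta$ changes under an FPST while the connectivity Jacobian $J_c = J_\theta \operatorname{diag}(\theta)$ does not:

```python
import numpy as np

rng = np.random.default_rng(0)
D_IN, H, D_OUT = 8, 16, 4
theta = rng.normal(size=H * D_IN + D_OUT * H)  # flattened (W1, W2)

def f(theta, x):
    W1 = theta[:H * D_IN].reshape(H, D_IN)
    W2 = theta[H * D_IN:].reshape(D_OUT, H)
    return W2 @ np.maximum(W1 @ x, 0.0)

def jac_theta(theta, x, eps=1e-6):
    """Central finite-difference Jacobian of f w.r.t. the parameters."""
    J = np.zeros((D_OUT, theta.size))
    for i in range(theta.size):
        tp, tm = theta.copy(), theta.copy()
        tp[i] += eps
        tm[i] -= eps
        J[:, i] = (f(tp, x) - f(tm, x)) / (2 * eps)
    return J

# Function-preserving rescaling of hidden unit k (a diagonal transformation T).
sigma, k = 10.0, 3
T = np.ones_like(theta)
T[k * D_IN:(k + 1) * D_IN] = sigma   # input edges of unit k
T[H * D_IN + k::H] = 1.0 / sigma     # output edges of unit k
theta_t = T * theta

x = rng.normal(size=D_IN)
assert np.allclose(f(theta, x), f(theta_t, x))   # FPST: function preserved
J1, J2 = jac_theta(theta, x), jac_theta(theta_t, x)
print(np.allclose(J1, J2))                    # False: J_theta is not invariant
print(np.allclose(J1 * theta, J2 * theta_t))  # True: J_c is invariant
```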
6. Key Idea – Design of the PAC-Bayes prior and posterior
• Apply Bayesian linear regression to design the PAC-Bayes prior and posterior
• PAC-Bayes prior: an isotropic Gaussian distribution centered at the pre-trained parameter
• PAC-Bayes posterior: the posterior of Bayesian linear regression given the PAC-Bayes prior
$$P_{\theta_*}(\theta) := \mathcal{N}\big(\theta \,\big|\, \theta_*,\ \alpha^2 \operatorname{diag}(\theta_*)^2\big)$$

$$Q_{\theta_*}(\theta) = \mathcal{N}\big(\theta \,\big|\, \theta_* + \operatorname{diag}(\theta_*)\,\mu_Q,\ \operatorname{diag}(\theta_*)\,\Sigma_Q\,\operatorname{diag}(\theta_*)\big)$$

where

$$\mu_Q := \frac{\Sigma_Q J_c^\top (Y - f(X, \theta_*))}{\sigma^2} = \frac{\Sigma_Q \operatorname{diag}(\theta_*)\, J_\theta^\top (Y - f(X, \theta_*))}{\sigma^2}$$

$$\Sigma_Q := \left( \frac{I_P}{\alpha^2} + \frac{J_c^\top J_c}{\sigma^2} \right)^{-1} = \left( \frac{I_P}{\alpha^2} + \frac{\operatorname{diag}(\theta_*)\, J_\theta^\top J_\theta \operatorname{diag}(\theta_*)}{\sigma^2} \right)^{-1}$$

• Check Appendix D for the detailed derivation!
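For illustration, a dense-matrix NumPy sketch of these posterior formulas (the function name, toy shapes, and default scales are my own assumptions; the paper's large-scale computation differs):

```python
import numpy as np

def connectivity_posterior(J_theta, theta_star, residual, alpha=0.1, sigma=1.0):
    """mu_Q, Sigma_Q as above, plus the induced posterior over parameters.
    alpha is the prior scale, sigma the assumed observation noise std."""
    P = theta_star.size
    J_c = J_theta * theta_star  # J_theta @ diag(theta_*)
    Sigma_Q = np.linalg.inv(np.eye(P) / alpha**2 + J_c.T @ J_c / sigma**2)
    mu_Q = Sigma_Q @ J_c.T @ residual / sigma**2
    mean = theta_star + theta_star * mu_Q                 # theta_* + diag(theta_*) mu_Q
    cov = Sigma_Q * np.outer(theta_star, theta_star)      # diag(theta_*) Sigma_Q diag(theta_*)
    return mean, cov

rng = np.random.default_rng(0)
J = rng.normal(size=(32, 10))       # hypothetical Jacobian (N outputs x P params)
theta_star = rng.normal(size=10)    # pre-trained parameter
residual = rng.normal(size=32)      # Y - f(X, theta_*)
mean, cov = connectivity_posterior(J, theta_star, residual)
```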
7. Theoretical results
• PAC-Bayes-CTK and its invariance
• The perturbation term relates to the Complexity Measure of Data (CMD) term of Arora et al. (2019).
• The sharpness term is positively correlated with the eigenvalues of the CTK.
• Each term in PAC-Bayes-CTK is invariant to FPSTs.
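A small sketch of how the CTK spectrum enters in practice. I assume here, as a proxy, that the sharpness term is summarized by the CTK eigenvalues; the trace tr(Θ_c) = ||J_c||_F² is one cheap aggregate, in the spirit of the low-complexity Connectivity Sharpness metric mentioned in the conclusion:

```python
import numpy as np

def ctk(J_theta, theta_star):
    """Connectivity Tangent Kernel Theta_c = J_c J_c^T, J_c = J_theta diag(theta_*)."""
    J_c = J_theta * theta_star
    return J_c @ J_c.T

rng = np.random.default_rng(1)
J = rng.normal(size=(20, 50))      # hypothetical Jacobian (N outputs x P params)
theta_star = rng.normal(size=50)
eigs = np.linalg.eigvalsh(ctk(J, theta_star))
print(eigs[-3:])                   # top CTK eigenvalues (drive the sharpness term)
print(eigs.sum())                  # trace = ||J_c||_F^2, a cheap spectral summary
```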
9. Empirical results
• Robustness to the selection of the prior scale.
• The standard Linearized Laplace (LL) approximation is sensitive to the prior scale (α) and has a different optimal sharpness for each metric.
• Our PAC-Bayes posterior, called Connectivity Laplace (CL), is robust to the prior scale and consistent across metrics.
10. Conclusion
• Our contribution is threefold:
• We introduced a novel PAC-Bayes bound that guarantees invariance to FPSTs for a broad class of networks. We empirically verified that this bound gives tight results for a ResNet with 11M parameters.
• Based on the sharpness term of our bound, we provided a low-complexity sharpness metric, called Connectivity Sharpness.
• To prevent overconfident predictions, we showed how our PAC-Bayes posterior can be used to resolve the pitfalls of WD with BN layers, demonstrating its practicality.