eng
Schloss Dagstuhl – Leibniz-Zentrum für Informatik
Leibniz International Proceedings in Informatics
1868-8969
2021-02-04
63:1
63:15
10.4230/LIPIcs.ITCS.2021.63
article
Training (Overparametrized) Neural Networks in Near-Linear Time
van den Brand, Jan [1]
Peng, Binghui [2]
Song, Zhao [3]
Weinstein, Omri [2]
[1] KTH Royal Institute of Technology, Stockholm, Sweden
[2] Columbia University, New York, NY, USA
[3] Princeton University and Institute for Advanced Study, NJ, USA
The slow convergence rate and pathological curvature issues of first-order gradient methods for training deep neural networks initiated an ongoing effort to develop faster second-order optimization algorithms beyond SGD, without compromising the generalization error. Despite their remarkable convergence rate (independent of the training batch size n), second-order algorithms incur a daunting slowdown in the cost per iteration (inverting the Hessian matrix of the loss function), which renders them impractical. Very recently, this computational overhead was mitigated by the works of [Zhang et al., 2019; Cai et al., 2019], yielding an O(mn²)-time second-order algorithm for training two-layer overparametrized neural networks of polynomial width m.
We show how to speed up the algorithm of [Cai et al., 2019], achieving an Õ(mn)-time backpropagation algorithm for training (mildly overparametrized) ReLU networks, which is near-linear in the dimension (mn) of the full gradient (Jacobian) matrix. The centerpiece of our algorithm is to reformulate the Gauss-Newton iteration as an 𝓁₂-regression problem, and then use a Fast-JL type dimension reduction to precondition the underlying Gram matrix in time independent of m, allowing us to find a sufficiently good approximate solution via first-order conjugate gradient. Our result provides a proof-of-concept that advanced machinery from randomized linear algebra, which led to recent breakthroughs in convex optimization (ERM, LPs, regression), can be carried over to the realm of deep learning as well.
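The sketch-and-precondition idea behind the abstract's regression step can be illustrated in a few lines. The toy sketch below solves min‖Ax − b‖₂ by sketching A, QR-factoring the sketch to obtain a preconditioner R, and running conjugate gradient on the preconditioned normal equations. This is only a minimal illustration of the generic technique, not the paper's algorithm: a dense Gaussian sketch stands in for the Fast-JL transform, and the function name and parameters are made up for this example.

```python
import numpy as np


def _cg(matvec, rhs, tol=1e-12, max_iter=100):
    """Plain conjugate gradient for a symmetric positive-definite operator."""
    x = np.zeros_like(rhs)
    r = rhs.copy()          # residual
    p = r.copy()            # search direction
    rs = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def sketch_precondition_lsq(A, b, sketch_rows=None, seed=0):
    """Solve min_x ||Ax - b||_2 by sketch-and-precondition + CG.

    A Gaussian sketch plays the role of the Fast-JL transform here
    (same preconditioning effect, simpler to code)."""
    n, d = A.shape
    s = sketch_rows or 4 * d                     # modest oversampling
    rng = np.random.default_rng(seed)
    S = rng.standard_normal((s, n)) / np.sqrt(s)  # sketching matrix
    _, R = np.linalg.qr(S @ A)                    # SA = QR; R preconditions A

    # CG on (A R^{-1})^T (A R^{-1}) y = (A R^{-1})^T b; the preconditioned
    # matrix A R^{-1} has O(1) condition number w.h.p., so CG converges fast.
    def matvec(y):
        z = np.linalg.solve(R, y)                 # R^{-1} y
        return np.linalg.solve(R.T, A.T @ (A @ z))

    rhs = np.linalg.solve(R.T, A.T @ b)
    y = _cg(matvec, rhs)
    return np.linalg.solve(R, y)                  # x = R^{-1} y
```

Because the preconditioned system is well-conditioned, the CG loop needs only a few iterations regardless of how tall A is, mirroring the abstract's point that the expensive dimension enters only through cheap matrix-vector products.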
https://drops.dagstuhl.de/storage/00lipics/lipics-vol185-itcs2021/LIPIcs.ITCS.2021.63/LIPIcs.ITCS.2021.63.pdf
Deep learning theory
Nonconvex optimization