CSC2541 Winter 2021
Topics in Machine Learning:
Neural Net Training Dynamics

Overview

Neural nets have achieved amazing results over the past decade in domains as broad as vision, speech, language understanding, medicine, robotics, and game playing. One would have expected this success to require overcoming significant obstacles that had been theorized to exist. After all, the optimization landscape is nonconvex, highly nonlinear, and high-dimensional, so why are we able to train these networks? In many cases, they have far more than enough parameters to memorize the data, so why do they generalize well? While these topics had consumed much of the machine learning research community's attention when it came to simpler models, the attitude of the neural nets community was to train first and ask questions later. Apparently this worked.

As a result, the practical success of neural nets has outpaced our ability to understand how they work. This class is about developing the conceptual tools to understand what happens when a neural net trains. Some of the ideas were established decades ago (and perhaps forgotten by much of the community), and others are just beginning to be understood today. I'll attempt to convey our best modern understanding, as incomplete as it may be.

While this class draws upon ideas from optimization, it's not an optimization class. For one thing, the study of optimization is often prescriptive, starting with information about the optimization problem and a well-defined goal such as fast convergence in a particular norm, and figuring out a plan that's guaranteed to achieve it. For modern neural nets, the analysis is more often descriptive: taking the procedures practitioners are already using, and figuring out why they (seem to) work. Hopefully this understanding will let us improve the algorithms.

Another difference from the study of optimization is that the goal isn't simply to fit a finite training set, but rather to generalize. Why neural nets generalize despite their enormous capacity is intimately tied to the dynamics of training. Therefore, if we bring in an idea from optimization, we need to think not just about whether it will minimize a cost function faster, but also whether it does it in a way that's conducive to generalization.

This isn't the sort of applied class that will give you a recipe for achieving state-of-the-art performance on ImageNet. Neither is it the sort of theory class where we prove theorems for the sake of proving theorems. Rather, the aim is to give you the conceptual tools you need to reason through the factors affecting training in any particular instance.

Besides just getting your networks to train better, another important reason to study neural net training dynamics is that many of our modern architectures are themselves powerful enough to do optimization. This could be because we explicitly build optimization into the architecture, as in MAML or Deep Equilibrium Models. Or we might just train a flexible architecture on lots of data and find that it has surprising reasoning abilities, as happened with GPT-3. Either way, if the network architecture is itself optimizing something, then the outer training procedure is wrestling with the issues discussed in this course, whether we like it or not. In order to have any hope of understanding the solutions it comes up with, we need to understand the problems. Therefore, this course will finish with bilevel optimization, drawing upon everything covered up to that point in the course.

Where and When

Time: Thursdays 2:10-5:00

Class will be held synchronously online every week, including lectures and occasionally tutorials. We have 3 hours scheduled for lecture and/or tutorial. Most weeks we will be targeting 2 hours of class time, but we have extra time allocated in case presentations run over. Students are encouraged to attend class each week.

Online delivery. Lectures will be delivered synchronously via Zoom, and recorded for asynchronous viewing by enrolled students. Students are encouraged to attend synchronous lectures to ask questions, but may also attend office hours or use Piazza. All information about attending virtual lectures, tutorials, and office hours will be sent to enrolled students through Quercus.

Course videos and materials belong to your instructor, the University, and/or other source depending on the specific facts of each situation, and are protected by copyright. In this course, you are permitted to download session videos and materials for your own academic use, but you should not copy, share, or use them for any other purpose without the explicit permission of the instructor. For questions about recording and use of videos in which you appear please contact your instructor.

Teaching Staff

Assignments and Grading

Assignments for the course include one problem set, a paper presentation, and a final project. The marking scheme is as follows:

The problem set will give you a chance to practice the content of the first three lectures, and will be due on Feb 10. It is individual work. Here are the materials:

For the Colab notebook and paper presentation, you will form a group of 2-3 and pick one paper from a list. Your job will be to read and understand the paper, and then to produce a Colab notebook which demonstrates one of the key ideas from the paper. A sign-up sheet will be distributed via email. The details of the assignment are here.

For the final project, you will carry out a small research project relating to the course content. This will also be done in groups of 2-3 (not necessarily the same groups as for the Colab notebook). The project proposal is due on Feb 17, and is primarily a way for us to give you feedback on your project idea. The final report is due April 7. More details can be found in the project handout.

Software

For this class, we'll use Python and the JAX deep learning framework. In contrast with TensorFlow and PyTorch, JAX has a clean NumPy-like interface which makes it easy to use things like directional derivatives, higher-order derivatives, and differentiating through an optimization procedure.

There are several neural net libraries built on top of JAX. Depending on what you're trying to do, you have several options:

You are welcome to use whatever language and framework you like for the final project. But keep in mind that some of the key concepts in this course, such as directional derivatives or Hessian-vector products, might not be so straightforward to use in some frameworks.

Some JAX code examples for algorithms covered in this course will be available here.

Schedule

This is a tentative schedule, which will likely change as the course goes on.

Overwhelmed? Check out CSC2541 for the Busy.

# Date Topic Readings
1 1/14

A Toy Model: Linear Regression [Slides]

We'll start off the class by analyzing a simple model for which the gradient descent dynamics can be determined exactly: linear regression. Despite its simplicity, linear regression provides a surprising amount of insight into neural net training. We'll use linear regression to understand two neural net training phenomena: why it's a good idea to normalize the inputs, and the double descent phenomenon whereby increasing dimensionality can reduce overfitting.
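
As a rough illustration of what "determined exactly" means here, the sketch below runs gradient descent on a small synthetic regression problem and compares the result with the closed-form solution obtained by decoupling the dynamics along the eigenvectors of the Hessian. The data, the column scales, the step size, and all variable names are assumptions made up for this example, not course materials.

```python
import jax
import jax.numpy as jnp

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
num_examples, num_features = 50, 3

# Hypothetical data: the three input columns have very different scales.
X = jax.random.normal(key_x, (num_examples, num_features)) * jnp.array([1.0, 0.5, 0.1])
w_true = jax.random.normal(key_w, (num_features,))
t = X @ w_true  # noiseless targets, so the unique minimizer is w_true


def loss(w):
    return 0.5 * jnp.mean((X @ w - t) ** 2)


# The Hessian of this cost is constant; its eigenvectors decouple the dynamics.
H = X.T @ X / num_examples
eigvals, Q = jnp.linalg.eigh(H)  # eigenvalues in ascending order
alpha = 1.0 / eigvals[-1]        # step size set by the largest curvature
num_steps = 20

# Run gradient descent from w = 0.
grad_loss = jax.grad(loss)
w_gd = jnp.zeros(num_features)
for _ in range(num_steps):
    w_gd = w_gd - alpha * grad_loss(w_gd)

# Closed form: along eigenvector i, the error shrinks by
# a factor (1 - alpha * eigvals[i]) on every step.
initial_error = Q.T @ (jnp.zeros(num_features) - w_true)
decay = (1.0 - alpha * eigvals) ** num_steps
w_closed = w_true + Q @ (decay * initial_error)
```

The entries of `decay` show how much of the initial error remains in each eigendirection. With badly scaled inputs the low-curvature directions barely move, which is one way to see why normalizing the inputs helps.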

Tutorial: JAX, part 1 [Colab]

2 1/21

Taylor Approximations [Slides]

Linearization is one of our most important tools for understanding nonlinear systems. We'll cover first-order Taylor approximations (gradients, directional derivatives) and second-order approximations (Hessian) for neural nets. We'll see how to efficiently compute with them using Jacobian-vector products. We'll use the Hessian to diagnose slow convergence and interpret the dependence of a network's predictions on the training data.
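
A minimal sketch of these primitives in JAX is shown below. The function `f` is a hypothetical stand-in for a training loss, and `hvp` is a helper name chosen for this example; it uses the forward-over-reverse pattern (a Jacobian-vector product applied to the gradient function), which is one of several ways to get Hessian-vector products.

```python
import jax
import jax.numpy as jnp


def f(w):
    # Hypothetical scalar function standing in for a training loss.
    return jnp.sum(jnp.tanh(w) ** 2) + w[0] * w[1]


def hvp(fun, w, v):
    # Forward-over-reverse: push the direction v through the gradient function.
    # This never forms the full Hessian.
    return jax.jvp(jax.grad(fun), (w,), (v,))[1]


w = jnp.array([0.3, -1.2, 0.7])
v = jnp.array([1.0, 0.5, -2.0])

# First order: function value and directional derivative in one forward pass.
f_w, dir_deriv = jax.jvp(f, (w,), (v,))

# Second order: Hessian-vector product, checked against the dense Hessian,
# which is only affordable here because the problem is tiny.
Hv = hvp(f, w, v)
H_dense = jax.hessian(f)(w)
```

For a real network the dense Hessian is far too large to build, which is why the matrix-free `hvp` form is the one that matters in practice.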

Tutorial: JAX, part 2 [Colab]

3 1/28

Metrics [Slides]

Metrics give a local notion of distance on a manifold. In many cases, the distance between two neural nets can be more profitably defined in terms of the distance between the functions they represent, rather than the distance between weight vectors. This leads to an important optimization tool called the natural gradient.

4 2/4

Second-Order Optimization [Slides]

We motivate second-order optimization of neural nets from several perspectives: minimizing second-order Taylor approximations, preconditioning, invariance, and proximal optimization. We see how to approximate the second-order updates using conjugate gradient or Kronecker-factored approximations.
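
To make the conjugate gradient route concrete, here is a small sketch of a damped Newton step in which the linear system is solved with `jax.scipy.sparse.linalg.cg` using only Hessian-vector products. The logistic regression problem, the damping value, the L2 coefficient, and the helper name `newton_step` are all assumptions for illustration; a practical Hessian-free optimizer would add safeguards such as a line search or adaptive damping.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

key_x, key_t = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (100, 5))
t = (jax.random.uniform(key_t, (100,)) < 0.5).astype(jnp.float32)
l2 = 1e-2  # small L2 penalty so the cost is strictly convex


def loss(w):
    logits = X @ w
    # Logistic regression cross-entropy, written with a stable softplus.
    cross_entropy = jnp.mean(jnp.logaddexp(0.0, logits) - t * logits)
    return cross_entropy + 0.5 * l2 * jnp.sum(w ** 2)


def newton_step(w, damping=1e-3):
    g = jax.grad(loss)(w)

    def damped_hvp(v):
        # (H + damping * I) v, computed without forming H.
        return jax.jvp(jax.grad(loss), (w,), (v,))[1] + damping * v

    # Approximately solve (H + damping * I) d = -g by conjugate gradient.
    d, _ = cg(damped_hvp, -g, maxiter=50)
    return w + d


w = jnp.zeros(5)
for _ in range(5):
    w = newton_step(w)

final_grad_norm = jnp.linalg.norm(jax.grad(loss)(w))
```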

5 2/11

Adaptive Gradient Methods, Normalization, and Weight Decay [Slides]

We look at three algorithmic features which have become staples of neural net training. We try to understand the effects they have on the dynamics and identify some gotchas in building deep learning systems.

6 2/18

Infinite Limits and Overparameterization [Slides]

Systems often become easier to analyze in the limit. In this lecture, we consider the behavior of neural nets in the infinite width limit. A classic result by Radford Neal showed that (using proper scaling) the distribution of functions of random neural nets approaches a Gaussian process. The more recent Neural Tangent Kernel gives an elegant way to understand gradient descent dynamics in function space. Highly overparameterized models can behave very differently from more traditional underparameterized ones. Time permitting, we'll also consider the limit of infinite depth.
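
The infinite-width results are statements about a limit, but the finite-width analogue of the Neural Tangent Kernel is easy to compute: it is the Gram matrix of the network's parameter Jacobian. The sketch below does this for a small tanh MLP. The architecture, the 1/sqrt(fan_in) initialization, and the helper names (`init_mlp`, `mlp`, `f_flat`) are assumptions for this example, and the result is the empirical kernel at initialization, not the limiting one.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def init_mlp(key, sizes):
    # Weights drawn with standard deviation 1/sqrt(fan_in), biases at zero.
    params = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        W = jax.random.normal(subkey, (fan_in, fan_out)) / jnp.sqrt(fan_in)
        params.append((W, jnp.zeros(fan_out)))
    return params


def mlp(params, x):
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return (x @ W + b)[..., 0]  # one scalar output per input


params = init_mlp(jax.random.PRNGKey(0), [2, 64, 64, 1])
flat_params, unravel = ravel_pytree(params)


def f_flat(flat, inputs):
    # Same network, viewed as a function of one flat parameter vector.
    return mlp(unravel(flat), inputs)


inputs = jax.random.normal(jax.random.PRNGKey(1), (5, 2))

# Jacobian of the outputs with respect to all parameters: (num_inputs, num_params).
J = jax.jacobian(f_flat)(flat_params, inputs)

# Empirical NTK: inner products of per-example parameter gradients.
ntk = J @ J.T
```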

7 2/25

Stochastic Optimization and Scaling [Slides]

When can we take advantage of parallelism to train neural nets? Which optimization techniques are useful at which batch sizes? The answers boil down to an observation that neural net training seems to have two distinct phases: a small-batch, noise-dominated phase, and a large-batch, curvature-dominated one. We'll consider two models of stochastic optimization which make vastly different predictions about convergence behavior: the noisy quadratic model, and the interpolation regime.

8 3/4

Implicit Regularization and Bayesian Inference [Slides]

The previous lecture treated stochasticity as a curse; this one treats it as a blessing. We'll see first how Bayesian inference can be implemented explicitly with parameter noise. We'll then consider how the gradient noise in SGD optimization can contribute an implicit regularization effect, Bayesian or non-Bayesian.

9 3/11

Dynamical Systems and Momentum [Slides]

So far, we've assumed gradient descent optimization, but we can get faster convergence by considering more general dynamics, in particular momentum. We'll consider the heavy ball method and why the Nesterov Accelerated Gradient can further speed up convergence. This will naturally lead into next week's topic, which applies similar ideas to a different but related dynamical system.
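
As a small illustration of the speedup, the sketch below compares plain gradient descent with the heavy ball method on a two-dimensional quadratic with condition number 100. The quadratic, the starting point, and the step counts are assumptions for this example; the step size and momentum are the standard tuned values for a quadratic with known smallest and largest curvature.

```python
import math

import jax
import jax.numpy as jnp

mu, L = 1.0, 100.0  # smallest and largest curvature
curvatures = jnp.array([mu, L])


def loss(w):
    return 0.5 * jnp.sum(curvatures * w ** 2)


grad_loss = jax.grad(loss)

# Tuned gradient descent step size for this quadratic.
alpha_gd = 2.0 / (mu + L)

# Tuned heavy ball step size and momentum for this quadratic.
alpha_hb = 4.0 / (math.sqrt(L) + math.sqrt(mu)) ** 2
beta_hb = ((math.sqrt(L) - math.sqrt(mu)) / (math.sqrt(L) + math.sqrt(mu))) ** 2


def run(alpha, beta, num_steps=100):
    w = jnp.array([1.0, 1.0])
    velocity = jnp.zeros_like(w)
    for _ in range(num_steps):
        # Heavy ball update; with beta = 0 this is plain gradient descent.
        velocity = beta * velocity - alpha * grad_loss(w)
        w = w + velocity
    return loss(w)


loss_gd = run(alpha_gd, 0.0)
loss_hb = run(alpha_hb, beta_hb)
```

On this problem the per-step contraction improves from roughly (κ - 1)/(κ + 1) to roughly (√κ - 1)/(√κ + 1), where κ = L/μ, so `loss_hb` ends up many orders of magnitude below `loss_gd` after the same number of steps.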

10 3/18

Differentiable Games (Lecture by Guodong Zhang) [Slides]

Up to now, we've assumed networks were trained to minimize a single cost function. Things get more complicated when there are multiple networks being trained simultaneously to minimize different cost functions. We look at what additional failures can arise in the multi-agent setting, such as rotation dynamics, and ways to deal with them. We'll mostly focus on minimax optimization, or zero-sum games.
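
Rotation dynamics already show up in the simplest zero-sum game, min over x and max over y of f(x, y) = x * y. The sketch below compares simultaneous gradient descent-ascent, which spirals outward, with the extragradient method, which spirals inward. Extragradient is used here only as one well-known remedy chosen for this example; the step size, step count, and function names are likewise assumptions.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Bilinear zero-sum game: x minimizes f, y maximizes f.
    return x * y


grad_x = jax.grad(f, argnums=0)
grad_y = jax.grad(f, argnums=1)


def gda_step(x, y, lr):
    # Simultaneous gradient descent (on x) and ascent (on y).
    return x - lr * grad_x(x, y), y + lr * grad_y(x, y)


def extragradient_step(x, y, lr):
    # Look ahead with one GDA step, then update the original point
    # using the gradients evaluated at the look-ahead point.
    x_half, y_half = gda_step(x, y, lr)
    return x - lr * grad_x(x_half, y_half), y + lr * grad_y(x_half, y_half)


def run(step_fn, num_steps=500, lr=0.1):
    x, y = jnp.array(1.0), jnp.array(1.0)
    for _ in range(num_steps):
        x, y = step_fn(x, y, lr)
    return jnp.sqrt(x ** 2 + y ** 2)  # distance from the equilibrium (0, 0)


initial_distance = jnp.sqrt(2.0)
distance_gda = run(gda_step)
distance_eg = run(extragradient_step)
```

Each GDA step multiplies the squared distance from the equilibrium by 1 + lr², while each extragradient step multiplies it by 1 - lr² + lr⁴, which is below 1 for any step size smaller than 1.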

11 3/25

Bilevel Optimization I [Slides]

Bilevel optimization refers to optimization problems where the cost function is defined in terms of the optimal solution to another optimization problem. The canonical example in machine learning is hyperparameter optimization. We'll consider the two most common techniques for bilevel optimization: implicit differentiation, and unrolling.
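
The two techniques can be compared on a toy hyperparameter optimization problem: tuning the L2 penalty of a ridge regression so as to minimize validation loss. The sketch below computes the hypergradient once by differentiating through the unrolled inner loop and once with the implicit function theorem. The synthetic data, step counts, learning rate, and names such as `inner_solve` are assumptions for this example. The implicit formula assumes the inner problem has been solved to (near) convergence, which holds here because the inner problem is a well-conditioned quadratic.

```python
import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.PRNGKey(0), 5)
w_true = jax.random.normal(keys[0], (3,))
X_train = jax.random.normal(keys[1], (40, 3))
t_train = X_train @ w_true + 0.5 * jax.random.normal(keys[2], (40,))
X_val = jax.random.normal(keys[3], (40, 3))
t_val = X_val @ w_true + 0.5 * jax.random.normal(keys[4], (40,))


def train_loss(w, lam):
    # Inner objective: ridge regression with penalty strength lam.
    return 0.5 * jnp.mean((X_train @ w - t_train) ** 2) + 0.5 * lam * jnp.sum(w ** 2)


def val_loss(w):
    # Outer objective: validation error of the trained weights.
    return 0.5 * jnp.mean((X_val @ w - t_val) ** 2)


def inner_solve(lam, num_steps=100, lr=0.5):
    # Inner optimization by plain gradient descent from zero.
    w = jnp.zeros(3)
    for _ in range(num_steps):
        w = w - lr * jax.grad(train_loss)(w, lam)
    return w


lam = jnp.array(0.1)

# Unrolling: backpropagate through every step of the inner loop.
hypergrad_unrolled = jax.grad(lambda l: val_loss(inner_solve(l)))(lam)

# Implicit differentiation: use the optimality condition at w_star,
# d val / d lam = - grad_w(val)^T  H^{-1}  d(grad_w train)/d lam.
w_star = inner_solve(lam)
g_val = jax.grad(val_loss)(w_star)
H = jax.hessian(train_loss, argnums=0)(w_star, lam)
mixed = jax.jacobian(jax.grad(train_loss, argnums=0), argnums=1)(w_star, lam)
hypergrad_implicit = -g_val @ jnp.linalg.solve(H, mixed)
```

Unrolling needs memory proportional to the number of inner steps but makes no convergence assumption, while the implicit version needs only the final weights plus a linear solve. For large networks that solve would itself be done with Hessian-vector products instead of a dense Hessian.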

12 4/1

Bilevel Optimization II [Slides]

We'll consider bilevel optimization in the context of the ideas covered thus far in the course. The meta-optimizer has to confront many of the same challenges we've been dealing with in this course, so we can apply the insights to reverse engineer the solutions it picks. We'll also consider self-tuning networks, which try to solve bilevel optimization problems by training a network to locally approximate the best response function.

Student Presentations

Here are the slides and Colab notebooks from the student paper presentations.
Date Presenters Paper Materials
2/25 Kimia Hamidieh, Nathan Ng, Haoran Zhang A. Jacot, F. Gabriel, and C. Hongler. Neural tangent kernel: Convergence and generalization in neural networks. NeurIPS 2018. [Slides] [Colab]
3/4 Chris Zhang, Dami Choi, Anqi (Joyce) Yang P. Nakkiran, B. Neyshabur, and H. Sedghi. The deep bootstrap framework: Good online learners are good offline generalizers. ICLR 2021. [Slides] [Colab]
3/4 Borys Bryndak, Sergio Casas, and Sean Segal S. McCandlish, J. Kaplan, D. Amodei, and the OpenAI Dota Team. An empirical model of large-batch training. 2018. [Slides] [Colab]
3/11 Kelvin Wong, Siva Manivasagam, and Amanjit Singh Kainth C. Wei, S. Kakade, and T. Ma. The implicit and explicit regularization effects of dropout. ICML 2020. [Slides] [Colab]
3/11 Honghua Dong and Tianxing Li J. Cohen, S. Kaur, Y. Li, J. Z. Kolter, and A. Talwalkar. Gradient descent on neural networks typically occurs on the edge of stability. ICLR 2021. [Slides] [Colab]
3/18 Yuwen Xiong, Andrew Liao, and Jingkang Wang I. Sutskever, J. Martens, G. Dahl, and G. Hinton. On the importance of initialization and momentum in deep learning [Slides] [Colab 1] [Colab 2]
3/18 Jenny Bao, Sheldon Huang, and Skylar Hao A. M. Saxe, J. L. McClelland, and S. Ganguli. A mathematical theory of semantic development in deep neural networks [Slides] [Colab]
3/25 James Tu, Yangjun Ruan, and Jonah Philion C. Maddison, D. Paulin, Y.-W. Teh, B. O'Donoghue, and A. Doucet. Hamiltonian descent methods. [Slides] [Colab]
4/1 Haoping Xu, Zhihuan Yu, and Jingcheng Niu D. Maclaurin, D. Duvenaud, and R. P. Adams. Gradient-based Hyperparameter Optimization through Reversible Learning. [Slides] [Colab]
4/1 Alex Adam, Keiran Paster, and Jenny (Jingyi) Liu P.-W. Koh and P. Liang. Understanding black-box predictions via influence functions. [Slides] [Colab]
4/1 Alex Wang, Gary Leung, and Sasha Doubov S. Bai, J. Z. Kolter, and V. Koltun. Deep equilibrium models [Slides] [Colab]
4/1 Shihao Ma, Yichun Zhang, and Zilun Zhang C. Finn, P. Abbeel, and S. Levine. Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks [Slides] [Colab]