At NIPS 2017, Ali Rahimi was invited on stage to give a talk after a paper he co-authored received the “Test of Time” award. There, in front of several thousand researchers, he made an impassioned case for more rigor: more small problems to validate our assumptions, and more visibility into why our optimization algorithms work the way they do. The now-famous catchphrase of the talk was “alchemy”. He argued that the machine learning community has been effective at finding things that work, but less effective at understanding why the techniques we use work. A central example in his talk was Batch Normalization, now a nearly universal step in optimizing deep nets, but one where our accepted explanation of “reducing internal covariate shift” is less rigorous than one might hope. With apologies for the long preamble, this is the context in which today’s paper is such a welcome push in the direction Rahimi was advocating: small, focused experiments that try to build up knowledge from first principles. Specifically, it asks: “Does Batch Norm really work by reducing covariate shift?”

To test whether internal covariate shift is a plausible mechanism behind Batch Norm’s improved performance (which, empirically, is very solid), the authors run a few simple experiments. First, and most straightforwardly, they train a basic convolutional net with and without Batch Norm, pick a layer, and visualize that layer’s activation distribution over the course of training in both cases. They saw the expected performance boost, but the Batch Norm network’s activations did not look meaningfully more stable over time than those of the standard network. Second, the authors tested what would happen if they added non-zero-mean random noise *after* Batch Norm in the network.
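To make the noise experiment concrete, here is a minimal Python sketch (standard library only) of a training-mode Batch Norm forward pass for a single feature, followed by additive noise whose mean and spread are re-drawn on every call. The function names, the Gaussian noise model, and the ranges for its mean and standard deviation are illustrative assumptions, not the paper’s exact setup.

```python
import math
import random


def batch_norm(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """Training-mode BN for one feature: normalize with batch statistics, then scale and shift."""
    m = len(xs)
    mu = sum(xs) / m
    var = sum((x - mu) ** 2 for x in xs) / m  # biased variance over the batch
    return [gamma * (x - mu) / math.sqrt(var + eps) + beta for x in xs]


def noisy_batch_norm(xs, rng, gamma=1.0, beta=0.0):
    """BN followed by additive noise whose mean and spread change on every call.

    Hypothetical noise model: the ranges below are illustrative, not taken from the paper.
    """
    out = batch_norm(xs, gamma, beta)
    noise_mean = rng.uniform(-1.0, 1.0)  # random, generally non-zero mean, re-drawn each step
    noise_std = rng.uniform(0.5, 2.0)    # random, non-unit spread, re-drawn each step
    return [y + rng.gauss(noise_mean, noise_std) for y in out]


def mean_and_var(ys):
    """Mean and biased variance of a list of numbers."""
    m = len(ys)
    mu = sum(ys) / m
    return mu, sum((y - mu) ** 2 for y in ys) / m


rng = random.Random(0)
batch = [rng.gauss(3.0, 5.0) for _ in range(256)]  # the same pre-BN activations every step
mu, var = mean_and_var(batch_norm(batch))
print(f"clean BN:            mean={mu:+.3f} var={var:.3f}")
for step in range(3):
    mu, var = mean_and_var(noisy_batch_norm(batch, rng))
    print(f"noisy BN, step {step}: mean={mu:+.3f} var={var:.3f}")
```

The clean output keeps a mean near 0 and a variance near 1, while the noisy output’s statistics shift from one step to the next. That shifting distribution is exactly what the next layer would have to cope with if internal covariate shift mattered.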
By adding that noise, the authors were explicitly engineering internal covariate shift: if controlling it were Batch Norm’s primary benefit, you would expect the noise to neutralize BN’s good performance. The authors did indeed see noisier, less stable activation distributions in the noisy BN case (look in particular at the layer 13 activations in the attached image), yet noisy BN performed nearly as well as standard BN, and meaningfully better than the baseline model with neither noise nor BN.

As a final test, they approached “internal covariate shift” from a different definitional standpoint. Maybe a better way to think about it is in terms of the stability of each layer’s gradients in the face of updates made to the layers below it. That is to say: each parameter of the network moves in the direction of lower loss with all else held equal, but in practice the lower-level parameters change at the same time, which can throw off the direction the higher-layer parameter thought it needed to move in. So the authors calculated the “gradient delta” between the gradient the model actually trains on and the gradient you would get by recomputing it *after* all of the lower layers of the model had been updated, so that the distribution of inputs to that layer has changed. You would expect this gradient delta to be smaller with Batch Norm, but the authors found that, if anything, the opposite was true.

With none of these ideas panning out, the authors then introduce the best explanation they’ve found for BN’s improved performance: it smooths out the loss function that SGD is optimizing. A smoother loss surface generally means that the magnitudes of your gradients will be smaller, and also that the gradient will change more slowly as you move (i.e., a low second derivative).
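The gradient delta from the third experiment can be sketched on a deliberately tiny model. This Python sketch assumes a hypothetical chain of scalar “layers”, y = w_1 * w_2 * ... * w_L * x, trained with mean squared error and with no Batch Norm anywhere. For a chosen layer it computes the gradient G the optimizer would use, lets every lower layer take its own SGD step, recomputes that layer’s gradient G', and reports |G - G'|. The paper measures this as an l2 distance between gradient vectors in real networks with and without BN; this toy version only shows how the quantity is defined, not which kind of network yields a smaller one.

```python
import math


def mse(ws, xs, ts):
    """Mean squared error of the scalar chain y = w_1 * w_2 * ... * w_L * x."""
    p = math.prod(ws)
    return sum((p * x - t) ** 2 for x, t in zip(xs, ts)) / len(xs)


def grads(ws, xs, ts):
    """Analytic gradient of mse with respect to every layer's weight."""
    p = math.prod(ws)
    out = []
    for k in range(len(ws)):
        others = math.prod(ws[:k] + ws[k + 1:])  # product of all weights except w_k
        out.append(sum(2.0 * (p * x - t) * others * x for x, t in zip(xs, ts)) / len(xs))
    return out


def gradient_delta(ws, layer, xs, ts, lr):
    """|G - G'| for one layer.

    G  is the gradient the optimizer actually uses for ws[layer].
    G' is the same gradient recomputed after every lower layer (index < layer)
       has taken its own SGD step, so the layer's input has shifted.
    """
    g = grads(ws, xs, ts)
    shifted = [w - lr * gk if k < layer else w for k, (w, gk) in enumerate(zip(ws, g))]
    return abs(g[layer] - grads(shifted, xs, ts)[layer])


xs = [0.5, -1.0, 2.0, 1.5]
ts = [2.0 * x for x in xs]  # toy targets from y = 2x
ws = [0.8, 1.1, 0.9]        # hypothetical 3-layer weights
for lr in (0.01, 0.1):
    print(f"lr={lr}: gradient delta at the top layer = {gradient_delta(ws, 2, xs, ts, lr):.4f}")
```

For the bottom layer (nothing below it moves) or with lr=0 the delta is exactly zero; in this example it is larger at lr=0.1 than at lr=0.01, because the lower layers move further within one update.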
As support for the smoothing idea, they show markedly different results for BN and standard models on measures such as how well the gradient at one point predicts the gradient you get after taking a step in that gradient’s direction. BN models have meaningfully more predictive gradients, along with less variation in the value of the loss along the gradient direction. The argument for why BN would produce this outcome is tied up in math that’s hard to explain without LaTeX, but it comes down to the idea that Batch Norm shrinks the gradients flowing back through each layer: because every activation is normalized with statistics computed over the whole batch, the part of the gradient shared across the batch is averaged out, and what remains is rescaled.

As Rahimi said in his talk, a lot of modern modeling is “applying brittle optimization techniques to loss surfaces we don’t understand.” And, by and large, that is true: it’s devilishly difficult to get a good handle on what loss surfaces are doing when they’re doing it in several-million-dimensional space. But the difficulty doesn’t mean we should give up on searching for principles to build our understanding on, and I think this paper is a fantastic example of how that can be done well.
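Returning to the mechanism behind the smoothing argument, one ingredient can be seen in the backward pass of Batch Norm itself. For a single feature with batch mean mu, sigma = sqrt(var + eps), normalized inputs x_hat and upstream gradient g = dL/dy, the standard training-mode gradient is dL/dx_i = (gamma / sigma) * (g_i - mean(g) - x_hat_i * mean(g * x_hat)). The two subtracted terms remove the part of g shared by the whole batch and the part aligned with x_hat, and the result is scaled by gamma / sigma, so its norm can never exceed (gamma / sigma) * |g|. The Python sketch below (standard library only; helper names are ours) checks this numerically. It illustrates the mechanism under these assumptions and is not a reproduction of the paper’s theorem.

```python
import math
import random


def batch_stats(xs, eps=1e-5):
    """Batch mean and standard deviation (biased variance plus eps), as in training-mode BN."""
    m = len(xs)
    mu = sum(xs) / m
    return mu, math.sqrt(sum((x - mu) ** 2 for x in xs) / m + eps)


def bn_forward(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """Training-mode BN forward pass for one feature."""
    mu, sigma = batch_stats(xs, eps)
    return [gamma * (x - mu) / sigma + beta for x in xs]


def bn_backward(xs, upstream, gamma=1.0, eps=1e-5):
    """Return dL/dx for one feature, given upstream = dL/dy from the layer above.

    Closed form: dL/dx_i = (gamma / sigma) * (g_i - mean(g) - x_hat_i * mean(g * x_hat)).
    """
    m = len(xs)
    mu, sigma = batch_stats(xs, eps)
    x_hat = [(x - mu) / sigma for x in xs]
    g_mean = sum(upstream) / m  # part of the gradient shared by every example in the batch
    gx_mean = sum(g * xh for g, xh in zip(upstream, x_hat)) / m  # part aligned with x_hat
    return [gamma / sigma * (g - g_mean - xh * gx_mean) for g, xh in zip(upstream, x_hat)]


def l2(v):
    """Euclidean norm of a list of numbers."""
    return math.sqrt(sum(a * a for a in v))


rng = random.Random(0)
xs = [rng.gauss(0.0, 4.0) for _ in range(64)]        # pre-BN activations with spread around 4
upstream = [rng.gauss(1.0, 1.0) for _ in range(64)]  # upstream gradient with a large shared mean
grad_in = bn_backward(xs, upstream)
_, sigma = batch_stats(xs)
print(f"|dL/dy| = {l2(upstream):.3f}, |dL/dx| = {l2(grad_in):.3f}, bound = {l2(upstream) / sigma:.3f}")
```

Because the upstream gradient here was drawn with a large shared mean, the mean subtraction removes most of its norm, and the remainder is further scaled by 1/sigma (roughly 1/4 for these inputs).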