Variational Dropout Sparsifies Deep Neural Networks
Dmitry Molchanov, Arsenii Ashukha and Dmitry Vetrov
arXiv e-Print archive - 2017
Keywords: stat.ML, cs.LG
First published: 2017/01/19
Abstract: We explore the recently proposed variational dropout technique, which provided an elegant Bayesian interpretation to dropout. We extend variational dropout to the case when the dropout rate is unknown and show that it can be found by optimizing the variational lower bound on the evidence. We show that it is possible to assign and find individual dropout rates for each connection in a DNN. Interestingly, such an assignment leads to extremely sparse solutions in both fully-connected and convolutional layers. This effect is similar to the automatic relevance determination (ARD) effect in empirical Bayes, but has a number of advantages. We report up to 128-fold compression of popular architectures without a large loss of accuracy, providing additional evidence that modern deep architectures are very redundant.
The authors introduce their contribution as an alternative way to approximate the KL divergence between the prior and the variational posterior used in [Variational Dropout and the Local Reparameterization Trick][kingma], one that allows unbounded variance on the multiplicative noise. When the noise variance parameter associated with a weight tends to infinity, the weight is effectively removed, and in their implementation this is exactly what they do.
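To fix notation (this is my summary of the setup shared by both papers): each weight is modelled as its mean parameter corrupted by multiplicative Gaussian noise,

$$w_{ij} = \theta_{ij}\,(1 + \sqrt{\alpha_{ij}}\,\epsilon_{ij}), \quad \epsilon_{ij} \sim \mathcal{N}(0, 1), \quad \text{equivalently} \quad w_{ij} \sim \mathcal{N}(\theta_{ij}, \alpha_{ij}\theta_{ij}^2),$$

so $\alpha_{ij} \to \infty$ means the weight is pure noise and can be pruned.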
There are some important details that differ from the [original algorithm][kingma] for per-weight variational dropout. For both methods, each dense layer is initialized as follows:
```
theta = initialize weight matrix with shape (number of input units, number of hidden units)
log_alpha = initialize zero matrix with shape (number of input units, number of hidden units)
b = biases initialized to zero with length the number of hidden units
```
Here `log_alpha` parameterises the variance of the variational posterior.
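For concreteness, here is a minimal NumPy sketch of that initialization (the layer sizes and the Gaussian scale on `theta` are my own assumptions, not values from the paper):
```
import numpy as np

n_in, n_out = 784, 300                        # hypothetical layer sizes
theta = 0.01 * np.random.randn(n_in, n_out)   # weight means
log_alpha = np.zeros((n_in, n_out))           # per-weight log dropout rate
b = np.zeros(n_out)                           # biases
```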
In the original paper the algorithm was the following:
```
mean = dot(input, theta) + b # standard dense layer
# marginal variance over activations (eq. 10 in [original paper][kingma])
variance = dot(input^2, theta^2 * exp(log_alpha))
# sample from marginal distribution by scaling Normal
activations = mean + sqrt(variance)*unit_normal(number of output units)
```
The final step is a standard [reparameterization trick][shakir], but since we sample from the marginal distribution over activations rather than over the weights themselves, this is referred to as the local reparameterization trick (directly inspired by the [fast dropout paper][fast]).
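Putting those steps together, a hedged NumPy sketch of the locally reparameterized forward pass (batching over rows of `x` is my own addition; shapes follow the pseudocode above) could look like:
```
import numpy as np

def vd_dense_forward(x, theta, log_alpha, b):
    # x: (batch, n_in); theta, log_alpha: (n_in, n_out); b: (n_out,)
    mean = x @ theta + b                                # standard dense layer
    var = (x ** 2) @ (theta ** 2 * np.exp(log_alpha))   # marginal variance of the activations
    eps = np.random.randn(*mean.shape)                  # one Gaussian sample per activation
    return mean + np.sqrt(var) * eps
```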
The sparsifying algorithm starts with an alternative parameterisation for `log_alpha`:
```
log_sigma2 = matrix filled with negative constant (default -8) with size (number of input units, number of hidden units)
log_alpha = log_sigma2 - log(theta^2)
log_alpha = log_alpha clipped to the range [-8, 8]
```
The authors discuss this in Section 4.1: the $\sigma_{ij}^2$ term corresponds to an additive noise variance on each weight, with $\sigma_{ij}^2 = \alpha_{ij}\theta_{ij}^2$. Since this relation can be inverted to recover `log_alpha`, the forward pass remains unchanged, but the variance of the stochastic gradients is reduced. It is quite a counter-intuitive trick; so much so that I can't quite believe it works.
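In code, my understanding is that the reparameterisation amounts to treating `log_sigma2` as the free parameter and deriving `log_alpha` from it. A minimal sketch, with the init constant and clipping range taken from the pseudocode above and the layer sizes assumed as before:
```
import numpy as np

n_in, n_out = 784, 300                          # hypothetical layer sizes
theta = 0.01 * np.random.randn(n_in, n_out)     # weight means, as before
log_sigma2 = np.full((n_in, n_out), -8.0)       # free parameter: log of the additive noise variance

# invert sigma^2 = alpha * theta^2  =>  log_alpha = log_sigma2 - log(theta^2)
# (the small epsilon inside the log is my addition, for numerical stability)
log_alpha = np.clip(log_sigma2 - np.log(theta ** 2 + 1e-16), -8.0, 8.0)
```
Gradients are then taken with respect to `log_sigma2` (and `theta`) rather than `log_alpha`, which is where the variance reduction comes from.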
They then define a mask that removes the contribution of any weight whose noise variance has grown too large:
```
clip_mask = matrix shape of log_alpha, equals 1 if log_alpha is greater than thresh (default 3)
```
The clip mask is used to set the corresponding elements of `theta` to zero, after which the forward pass is exactly the same as in the original paper.
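A sketch of how this thresholding might plug into the forward pass (the threshold value comes from the pseudocode above; `vd_dense_forward` and the parameters are from my earlier sketches):
```
x = np.random.randn(32, n_in)                   # hypothetical input batch

# zero out weights whose dropout rate has effectively diverged
clip_mask = log_alpha > 3.0                     # True where a weight is considered removed
theta_pruned = np.where(clip_mask, 0.0, theta)  # set those weight means to zero

activations = vd_dense_forward(x, theta_pruned, log_alpha, b)
```
The per-layer sparsity level is then just `clip_mask.mean()`.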
The difference in the approximation to the KL divergence is illustrated in Figure 1 of the paper: the sparsifying version tends to zero as the variance increases, which matches the true KL divergence. In the [original paper][kingma] the approximation diverges from the true KL instead, forcing them to clip the variances at a certain point (they constrain $\alpha \leq 1$).
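For reference, the approximation to the (negative) KL term proposed in the paper can be written as a small function; the sigmoid-plus-softplus form and the constants below are my transcription of the paper, so treat this as a sketch rather than gospel:
```
import numpy as np

def neg_kl_approx(log_alpha, k1=0.63576, k2=1.87320, k3=1.48695):
    # approximate -KL(q(w) || log-uniform prior), per weight;
    # tends to 0 as log_alpha -> +inf, so fully "dropped" weights cost nothing
    sigmoid = 1.0 / (1.0 + np.exp(-(k2 + k3 * log_alpha)))
    return k1 * sigmoid - 0.5 * np.log1p(np.exp(-log_alpha)) - k1
```
Summed over all weights and added to the expected log-likelihood, this gives the variational lower bound that gets maximized.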
[kingma]: https://arxiv.org/abs/1506.02557
[shakir]: http://blog.shakirm.com/2015/10/machine-learning-trick-of-the-day-4-reparameterisation-tricks/
[fast]: http://proceedings.mlr.press/v28/wang13a.html