[link]
This is an interestingly pragmatic paper that makes a super simple observation. Often, we want a usable network with fewer parameters, so that it can run more easily on small devices. It's been observed (by these same authors, in fact) that pruned networks can achieve accuracy comparable to their fully trained counterparts if you rewind and retrain from early in the training process, to compensate for the loss of the (not ultimately important) pruned weights. This observation has been dubbed the "Lottery Ticket Hypothesis", after the idea that there's some small effective subnetwork you can find if you sample enough networks.

Given these two facts (the usefulness of pruning, and the success of weight rewinding), the authors explore the effectiveness of various ways to train after pruning. Current standard practice is to prune low-magnitude weights, and then continue training the remaining weights from the values they had at pruning time, holding the learning rate constant at its final value (fine-tuning). The authors find that:

1. Weight rewinding, where you rewind weights to *near* their starting values and then retrain using the learning rates from early in training, outperforms fine-tuning from the values the weights had when you pruned, but also
2. Learning rate rewinding, where you keep the weights as they are but rewind the learning rate to what it was early in training, is actually the most effective for a given amount of training time/search cost.

To me, this feels a little bit like burying the lede: the takeaway seems to be that when you prune, it's beneficial to make your network more "elastic" (in the metaphor-to-neuroscience sense) so it can more effectively learn to compensate for the removed neurons. So, what was really valuable in weight rewinding was the ability to "heat up" learning on a smaller set of weights, so they could adapt more quickly.
And the fact that learning rate rewinding works better than weight rewinding suggests that there is value in the learned weights after all; that value is just outstripped by the benefit of rolling back to old learning rates. All in all, not a super radical conclusion, but a useful and practical one to have so clearly laid out in a paper.
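
To make the three retraining strategies concrete, here is a minimal sketch based on my reading of the summary above, not on the paper's code. It assumes a global magnitude criterion for pruning and a simple step learning-rate schedule; the names `magnitude_mask`, `step_lr`, and `retraining_plan`, as well as the schedule constants, are hypothetical and chosen only for illustration.

```python
import numpy as np

def magnitude_mask(weights, sparsity):
    """Boolean mask that keeps the largest-magnitude weights.

    Illustrative global magnitude pruning; ties at the threshold are pruned.
    """
    flat = np.sort(np.abs(weights), axis=None)
    n_prune = int(round(sparsity * flat.size))
    if n_prune == 0:
        return np.ones(weights.shape, dtype=bool)
    threshold = flat[n_prune - 1]
    return np.abs(weights) > threshold

def step_lr(epoch, base_lr=0.1, drop_every=30, factor=0.1):
    """Hypothetical step schedule: multiply the rate by `factor` every `drop_every` epochs."""
    return base_lr * factor ** (epoch // drop_every)

def retraining_plan(strategy, total_epochs, rewind_epoch):
    """Return (epoch whose checkpoint supplies the starting weights, learning rates).

    All three strategies retrain for the same number of epochs.
    """
    n_retrain = total_epochs - rewind_epoch
    early_schedule = [step_lr(e) for e in range(rewind_epoch, total_epochs)]
    if strategy == "fine_tune":
        # final weights, learning rate frozen at its last value
        return total_epochs, [step_lr(total_epochs - 1)] * n_retrain
    if strategy == "weight_rewind":
        # early weights, original schedule replayed from the rewind point
        return rewind_epoch, early_schedule
    if strategy == "lr_rewind":
        # final weights, but the schedule is replayed from the rewind point
        return total_epochs, early_schedule
    raise ValueError("unknown strategy: %s" % strategy)

w = np.array([[0.5, -0.05], [0.01, -2.0]])
print(magnitude_mask(w, sparsity=0.5))
for strategy in ("fine_tune", "weight_rewind", "lr_rewind"):
    start, lrs = retraining_plan(strategy, total_epochs=90, rewind_epoch=10)
    print(strategy, "starts from epoch", start, "first lr", lrs[0], "last lr", lrs[-1])
```

In this sketch, the only difference between weight rewinding and learning rate rewinding is which checkpoint supplies the starting weights; the learning-rate schedule is the same in both.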
[link]
In this note, I'll implement the [Stochastically Unbiased Marginalization Objective (SUMO)](https://openreview.net/forum?id=SylkYeHtwr) to estimate the log-partition function of an energy function.

Estimation of the log-partition function has many important applications in machine learning. Take latent variable models or Bayesian inference: the log-partition function of the posterior distribution $$p(z|x)=\frac{1}{Z}p(x|z)p(z)$$ is the log-marginal likelihood of the data $$\log Z = \log \int p(x|z)p(z)dz = \log p(x)$$.

More generally, let $U(x)$ be some energy function which induces some density function $p(x)=\frac{e^{-U(x)}}{\int e^{-U(x)} dx}$. The common practice is to look at a variational form of the log-partition function,
$$ \log Z = \log \int e^{-U(x)}dx = \max_{q(x)}\mathbb{E}_{q}[-U(x)-\log q(x)] \nonumber $$
Plugging in an arbitrary $q$ would normally yield a strict lower bound, which means
$$ \frac{1}{n}\sum_{i=1}^n \left(-U(x_i) - \log q(x_i)\right) \nonumber $$
for $x_i$ sampled *i.i.d.* from $q$, would be a biased estimate of $\log Z$. In particular, it would underestimate $\log Z$ in expectation.

To see this, let's define the energy function $U$ as follows:
$$ U(x_1,x_2)= - \log \left(\frac{1}{2}\cdot e^{-\frac{(x_1+2)^2 + x_2^2}{2}} + \frac{1}{2}\cdot\frac{1}{4}e^{-\frac{(x_1-2)^2 + x_2^2}{8}}\right) \nonumber $$
It is not hard to see that $U$ is the energy function of a mixture of Gaussians $\frac{1}{2}\mathcal{N}([-2,0], I) + \frac{1}{2}\mathcal{N}([2,0], 4I)$ with normalizing constant $Z=2\pi\approx6.28$ and $\log Z\approx1.8379$.
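
For completeness, the value $Z=2\pi$ can be checked by integrating the two terms of $e^{-U(x)}$ separately, using the two-dimensional Gaussian integral $\int e^{-\frac{\|x-\mu\|^2}{2\sigma^2}}dx = 2\pi\sigma^2$ (with $\sigma^2=1$ for the first term and $\sigma^2=4$ for the second):
$$ \begin{align*} Z &= \int e^{-U(x)}dx \\ &= \frac{1}{2}\int e^{-\frac{(x_1+2)^2 + x_2^2}{2}}dx + \frac{1}{8}\int e^{-\frac{(x_1-2)^2 + x_2^2}{8}}dx \\ &= \frac{1}{2}\cdot 2\pi + \frac{1}{8}\cdot 8\pi = 2\pi \end{align*} $$
Equivalently, $e^{-U(x)}$ is $2\pi$ times the mixture density at every point $x$.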
```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal as mvn  # assumed source of `mvn`

def U(x):
    # x has shape (n, 2); returns the energy of each row
    x1 = x[:,0]
    x2 = x[:,1]
    d2 = x2 ** 2
    return - np.log(np.exp(-((x1+2) ** 2 + d2)/2)/2 + np.exp(-((x1-2) ** 2 + d2)/8)/4/2)
```

To visualize the density corresponding to the energy, $p(x)\propto e^{-U(x)}$:

```python
xx = np.linspace(-5,5,200)
yy = np.linspace(-5,5,200)
X = np.meshgrid(xx,yy)
X = np.concatenate([X[0][:,:,None], X[1][:,:,None]], 2).reshape(-1,2)
unnormalized_density = np.exp(-U(X)).reshape(200,200)
plt.imshow(unnormalized_density)
plt.axis('off')
```

https://i.imgur.com/CZSyIQp.png

As a sanity check, let's also visualize the density of the mixture of Gaussians.

```python
N1, N2 = mvn([-2,0], 1), mvn([2,0], 4)
density = (np.exp(N1.logpdf(X))/2 + np.exp(N2.logpdf(X))/2).reshape(200,200)
plt.imshow(density)
plt.axis('off')
print(np.allclose(unnormalized_density / density - 2*np.pi, 0))
```

`True`

https://i.imgur.com/g4inQxB.png

Now if we estimate the log-partition function by estimating the variational lower bound, we get

```python
q = mvn([0,0],5)
xs = q.rvs(10000*5)
elbo = - U(xs) - q.logpdf(xs)
plt.hist(elbo, range(-5,10))
print("Estimate: %.4f / Ground truth: %.4f" % (elbo.mean(), np.log(2*np.pi)))
print("Empirical variance: %.4f" % elbo.var())
```

`Estimate: 1.4595 / Ground truth: 1.8379`
`Empirical variance: 0.9921`

https://i.imgur.com/vFzutuY.png

The lower bound can be tightened via importance sampling:
$$ \log \int e^{-U(x)} dx \geq \mathbb{E}_{q^K}\left[\log\left(\frac{1}{K}\sum_{j=1}^K \frac{e^{-U(x_j)}}{q(x_j)}\right)\right] \nonumber $$

> This bound is tighter for larger $K$ partly due to the [concentration of the average](https://arxiv.org/pdf/1906.03708.pdf) inside of the $\log$ function: when the random variable is more deterministic, using a local linear approximation near its mean is more accurate as there's less "mass" outside of some neighborhood of the mean.
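
The importance-weighted bound can also be evaluated in log-space, which avoids overflow when the log-weights $-U(x_j)-\log q(x_j)$ are large in magnitude. Below is a minimal sketch, assuming SciPy's `scipy.special.logsumexp` is available; the helper name `iwae_bound` and the toy log-weights are hypothetical and only for illustration.

```python
import numpy as np
from scipy.special import logsumexp

def iwae_bound(log_w, k):
    """Average of log((1/k) * sum_j exp(log_w_j)) over groups of k log-weights.

    log_w: 1-D array of log importance weights, -U(x_j) - log q(x_j),
           whose length is a multiple of k.
    """
    groups = np.asarray(log_w, dtype=float).reshape(-1, k)
    # log-mean-exp within each group of k samples, computed stably
    per_group = logsumexp(groups, axis=1) - np.log(k)
    return per_group.mean()

# Toy log-weights (not from the note); exponentiating them directly would overflow.
log_w = np.array([1000.0, 1001.0, 999.0, 1000.5])
print(iwae_bound(log_w, 1))  # K=1 reduces to the plain average of the log-weights
print(iwae_bound(log_w, 4))  # K=4 pools all four samples into a single estimate
```

By Jensen's inequality, pooling a fixed set of log-weights into larger groups can only increase the value, which is the sample-level counterpart of the bound getting tighter with $K$.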
Now if we use this new estimator with $K=5$

```python
k = 5
xs = q.rvs(10000*k)
elbo = - U(xs) - q.logpdf(xs)
iwlb = elbo.reshape(10000,k)
# average the k importance weights inside the log
iwlb = np.log(np.exp(iwlb).mean(1))
plt.hist(iwlb, range(-5,10))
print("Estimate: %.4f / Ground truth: %.4f" % (iwlb.mean(), np.log(2*np.pi)))
print("Empirical variance: %.4f" % iwlb.var())
```

`Estimate: 1.7616 / Ground truth: 1.8379`
`Empirical variance: 0.1544`

https://i.imgur.com/sCcsQd4.png

We see that both the bias and variance decrease.

Finally, we use the [Stochastically Unbiased Marginalization Objective](https://openreview.net/pdf?id=SylkYeHtwr) (SUMO), which uses the *Russian Roulette* estimator to randomly truncate a telescoping series that converges in expectation to the log-partition function. Let $\text{IWAE}_k = \log\left(\frac{1}{k}\sum_{j=1}^k \frac{e^{-U(x_j)}}{q(x_j)}\right)$ be the importance-weighted estimator, and $\Delta_k = \text{IWAE}_{k+1} - \text{IWAE}_k$ be the difference (which can be thought of as some form of correction). The SUMO estimator is defined as
$$ \text{SUMO} = \text{IWAE}_1 + \sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \nonumber $$
where $K\sim p(K)=\mathbb{P}(\mathcal{K}=K)$.
To see why this is an unbiased estimator,
$$ \begin{align*} \mathbb{E}[\text{SUMO}] &= \mathbb{E}\left[\text{IWAE}_1 + \sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \right] \nonumber\\ &= \mathbb{E}_{x's}\left[\text{IWAE}_1 + \mathbb{E}_{K}\left[\sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \right]\right] \nonumber \end{align*} $$
The inner expectation can be further expanded
$$ \begin{align*} \mathbb{E}_{K}\left[\sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \right] &= \sum_{K=1}^\infty p(K)\sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \\ &= \sum_{k=1}^\infty \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \sum_{K=k}^\infty p(K) \\ &= \sum_{k=1}^\infty \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \mathbb{P}(\mathcal{K}\geq k) \\ &= \sum_{k=1}^\infty\Delta_k \\ &= \text{IWAE}_{2} - \text{IWAE}_1 + \text{IWAE}_{3} - \text{IWAE}_2 + ... = \lim_{k\rightarrow\infty}\text{IWAE}_{k}-\text{IWAE}_1 \end{align*} $$
which shows $\mathbb{E}[\text{SUMO}] = \mathbb{E}[\text{IWAE}_\infty] = \log Z$.

> (N.B.) Some care needs to be taken when taking the limit. See the paper for a more formal derivation.

A choice of $p(K)$ proposed in the paper satisfies $\mathbb{P}(\mathcal{K}\geq K)=\frac{1}{K}$. We can sample such a $K$ easily using the [inverse CDF](https://en.wikipedia.org/wiki/Inverse_transform_sampling), $K=\lceil\frac{u}{1-u}\rceil$ where $u$ is sampled uniformly from the interval $[0,1]$. Rounding up (rather than down) is what gives the stated tail, since $\mathbb{P}(K\geq k)=\mathbb{P}\left(u>\frac{k-1}{k}\right)=\frac{1}{k}$.

Now putting things all together, we can estimate the log-partition function using SUMO.
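
As a brief aside before the full estimator, the rounding used when sampling $K$ matters for the tail $\mathbb{P}(\mathcal{K}\geq k)=\frac{1}{k}$. The sketch below is an empirical check, not code from the paper: it compares rounding $\frac{u}{1-u}$ up and down against the target tail. The sample size, seed, and variable names are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
u = rng.random(200_000)      # uniform samples in [0, 1)
ratio = u / (1 - u)

k_ceil = np.ceil(ratio).astype(np.int64)    # rounding up
k_floor = np.floor(ratio).astype(np.int64)  # rounding down

# Compare the empirical tail P(K >= k) of each rule with the target 1/k.
for k in range(1, 6):
    tail_ceil = (k_ceil >= k).mean()
    tail_floor = (k_floor >= k).mean()
    print("k=%d  target=%.3f  ceil=%.3f  floor=%.3f" % (k, 1 / k, tail_ceil, tail_floor))
```

Rounding up reproduces the $\frac{1}{k}$ tail, while rounding down gives $\frac{1}{k+1}$ and allows $K=0$. The weights `np.arange(1,k+1)` in the estimator code correspond to $\frac{1}{\mathbb{P}(\mathcal{K}\geq k)}=k$, so they match the rounded-up rule.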
```python
count = 0
bs = 10
iwlb = list()
while count <= 1000000:
    # sample the truncation level K with P(K >= k) = 1/k
    u = np.random.rand(1)
    k = np.ceil(u/(1-u)).astype(int)[0]
    xs = q.rvs(bs*(k+1))
    elbo = - U(xs) - q.logpdf(xs)
    iwlb_ = elbo.reshape(bs, k+1)
    # running log-averages: column j holds IWAE_{j+1}
    iwlb_ = np.log(np.cumsum(np.exp(iwlb_), 1) / np.arange(1,k+2))
    # SUMO = IWAE_1 + sum_k k * (IWAE_{k+1} - IWAE_k)
    iwlb_ = iwlb_[:,0] + ((iwlb_[:,1:k+1] - iwlb_[:,0:k]) * np.arange(1,k+1)).sum(1)
    count += bs * (k+1)
    iwlb.append(iwlb_)
iwlb = np.concatenate(iwlb)
plt.hist(iwlb, range(-5,10))
print("Estimate: %.4f / Ground truth: %.4f" % (iwlb.mean(), np.log(2*np.pi)))
print("Empirical variance: %.4f" % iwlb.var())
```

`Estimate: 1.8359 / Ground truth: 1.8379`
`Empirical variance: 4.1794`

https://i.imgur.com/04kPKo5.png

Indeed, the empirical average is quite close to the true log-partition function of the energy function. However, we can also see that the distribution of the estimator is much more spread out. In fact, it is very heavy-tailed.

Note that I did not tune the proposal distribution $q$ based on the ELBO, IWAE, or SUMO. In the paper, the authors propose to tune $q$ to minimize the variance of the $\text{SUMO}$ estimator, which might be an interesting trick to look at next.

(Reposted, see more details and code from https://www.chinweihuang.com/pages/sumo)