In this note, I'll implement the [Stochastically Unbiased Marginalization Objective (SUMO)](https://openreview.net/forum?id=SylkYeHtwr) to estimate the log-partition function of an energy function. Estimating the log-partition function has many important applications in machine learning. Take latent variable models or Bayesian inference, for example: the log-partition function of the posterior distribution $$p(z|x)=\frac{1}{Z}p(x|z)p(z)$$ is the log-marginal likelihood of the data $$\log Z = \log \int p(x|z)p(z)dz = \log p(x)$$.

More generally, let $U(x)$ be some energy function which induces some density function $p(x)=\frac{e^{-U(x)}}{\int e^{-U(x)} dx}$. The common practice is to look at a variational form of the log-partition function,

$$ \log Z = \log \int e^{-U(x)}dx = \max_{q(x)}\mathbb{E}_{q(x)}[-U(x)-\log q(x)] \nonumber $$

Plugging in an arbitrary $q$ would normally yield a strict lower bound, which means

$$ \frac{1}{n}\sum_{i=1}^n \left(-U(x_i) - \log q(x_i)\right) \nonumber $$

for $x_i$ sampled *i.i.d.* from $q$, would be a biased estimate for $\log Z$. In particular, it would underestimate $\log Z$ in expectation. To see this, let's define the energy function $U$ as follows:

$$ U(x_1,x_2)= - \log \left(\frac{1}{2}\cdot e^{-\frac{(x_1+2)^2 + x_2^2}{2}} + \frac{1}{2}\cdot\frac{1}{4}e^{-\frac{(x_1-2)^2 + x_2^2}{8}}\right) \nonumber $$

It is not hard to see that $U$ is the energy function of a mixture of Gaussians $\frac{1}{2}\mathcal{N}([-2,0], I) + \frac{1}{2}\mathcal{N}([2,0], 4I)$ with a normalizing constant $Z=2\pi\approx6.28$ and $\log Z\approx1.8379$.
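
The value $Z=2\pi$ can also be checked numerically before going further. The snippet below is a minimal sketch that integrates $e^{-U}$ with a plain Riemann sum; the function name `energy`, the grid range $[-14, 14]^2$ and the grid resolution are choices made here for illustration, not part of the original note.

```python
import numpy as np

def energy(x):
    """Energy U(x) from the text, evaluated at points x of shape (n, 2)."""
    x1, x2 = x[:, 0], x[:, 1]
    comp1 = 0.5 * np.exp(-((x1 + 2) ** 2 + x2 ** 2) / 2)
    comp2 = 0.5 * 0.25 * np.exp(-((x1 - 2) ** 2 + x2 ** 2) / 8)
    return -np.log(comp1 + comp2)

# Riemann sum of exp(-U) on a regular grid; the range is an assumption chosen
# wide enough that the mass outside the grid is negligible.
grid = np.linspace(-14.0, 14.0, 701)
step = grid[1] - grid[0]
g1, g2 = np.meshgrid(grid, grid)
points = np.stack([g1.ravel(), g2.ravel()], axis=1)

Z_hat = np.exp(-energy(points)).sum() * step ** 2
log_Z_hat = np.log(Z_hat)
print("Z estimate: %.4f / 2*pi: %.4f" % (Z_hat, 2 * np.pi))
print("log Z estimate: %.4f" % log_Z_hat)
```

Grid integration like this only works because the problem is two-dimensional; the sampling-based estimators in the rest of the note are what one would use when a grid is not feasible.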
```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal as mvn  # assumed source of `mvn` used below

def U(x):
    # energy of the unnormalized two-component mixture; x has shape (n, 2)
    x1 = x[:,0]
    x2 = x[:,1]
    d2 = x2 ** 2
    return - np.log(np.exp(-((x1+2) ** 2 + d2)/2)/2 + np.exp(-((x1-2) ** 2 + d2)/8)/4/2)
```

To visualize the density corresponding to the energy, $p(x)\propto e^{-U(x)}$:

```python
# evaluate exp(-U) on a 200 x 200 grid over [-5, 5]^2
xx = np.linspace(-5,5,200)
yy = np.linspace(-5,5,200)
X = np.meshgrid(xx,yy)
X = np.concatenate([X[0][:,:,None], X[1][:,:,None]], 2).reshape(-1,2)
unnormalized_density = np.exp(-U(X)).reshape(200,200)
plt.imshow(unnormalized_density)
plt.axis('off')
```

https://i.imgur.com/CZSyIQp.png

As a sanity check, let's also visualize the density of the mixture of Gaussians.

```python
# scalar covariance arguments are interpreted as 1*I and 4*I
N1, N2 = mvn([-2,0], 1), mvn([2,0], 4)
density = (np.exp(N1.logpdf(X))/2 + np.exp(N2.logpdf(X))/2).reshape(200,200)
plt.imshow(density)
plt.axis('off')
# the ratio of the two densities should equal Z = 2*pi everywhere
print(np.allclose(unnormalized_density / density - 2*np.pi, 0))
```

`True`

https://i.imgur.com/g4inQxB.png

Now if we estimate the log-partition function by estimating the variational lower bound, we get

```python
q = mvn([0,0],5)  # proposal distribution N(0, 5*I)
xs = q.rvs(10000*5)
elbo = - U(xs) - q.logpdf(xs)
plt.hist(elbo, range(-5,10))
print("Estimate: %.4f / Ground truth: %.4f" % (elbo.mean(), np.log(2*np.pi)))
print("Empirical variance: %.4f" % elbo.var())
```

`Estimate: 1.4595 / Ground truth: 1.8379`

`Empirical variance: 0.9921`

https://i.imgur.com/vFzutuY.png

The lower bound can be tightened via importance sampling:

$$ \log \int e^{-U(x)} dx \geq \mathbb{E}_{q^K}\left[\log\left(\frac{1}{K}\sum_{j=1}^K \frac{e^{-U(x_j)}}{q(x_j)}\right)\right] \nonumber $$

> This bound is tighter for larger $K$ partly due to the [concentration of the average](https://arxiv.org/pdf/1906.03708.pdf) inside of the $\log$ function: when the random variable is more deterministic, using a local linear approximation near its mean is more accurate as there's less "mass" outside of some neighborhood of the mean.

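
For completeness, here is a short sketch of why the importance-weighted quantity is a lower bound for every $K$. The shorthand $w_j$ is introduced here for convenience, and the argument assumes $q(x)>0$ wherever $e^{-U(x)}>0$. Each weight is an unbiased estimate of $Z$, and since $\log$ is concave, Jensen's inequality gives the bound:

$$ \begin{align*} w_j &= \frac{e^{-U(x_j)}}{q(x_j)}, \qquad \mathbb{E}_{q}[w_j] = \int q(x)\frac{e^{-U(x)}}{q(x)}dx = Z \\ \mathbb{E}_{q^K}\left[\log\left(\frac{1}{K}\sum_{j=1}^K w_j\right)\right] &\leq \log \mathbb{E}_{q^K}\left[\frac{1}{K}\sum_{j=1}^K w_j\right] = \log Z \end{align*} $$

Setting $K=1$ recovers the variational lower bound $\mathbb{E}_{q}[-U(x)-\log q(x)]$ from earlier.
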
Now if we use this new estimator with $K=5$

```python
k = 5
xs = q.rvs(10000*k)
elbo = - U(xs) - q.logpdf(xs)
# group the samples into 10000 estimates, each averaging k importance weights
iwlb = elbo.reshape(10000,k)
iwlb = np.log(np.exp(iwlb).mean(1))
plt.hist(iwlb, range(-5,10))
print("Estimate: %.4f / Ground truth: %.4f" % (iwlb.mean(), np.log(2*np.pi)))
print("Empirical variance: %.4f" % iwlb.var())
```

`Estimate: 1.7616 / Ground truth: 1.8379`

`Empirical variance: 0.1544`

https://i.imgur.com/sCcsQd4.png

We see that both the bias and variance decrease.

Finally, we use the [Stochastically Unbiased Marginalization Objective](https://openreview.net/pdf?id=SylkYeHtwr) (SUMO), which uses the *Russian Roulette* estimator to randomly truncate a telescoping series that converges in expectation to the log-partition function. Let $\text{IWAE}_K = \log\left(\frac{1}{K}\sum_{j=1}^K \frac{e^{-U(x_j)}}{q(x_j)}\right)$ be the importance-weighted estimator, and $\Delta_K = \text{IWAE}_{K+1} - \text{IWAE}_K$ be the difference (which can be thought of as some form of correction). The SUMO estimator is defined as

$$ \text{SUMO} = \text{IWAE}_1 + \sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \nonumber $$

where $K\sim p(K)=\mathbb{P}(\mathcal{K}=K)$.
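
The Russian Roulette construction is easier to see on a toy series whose sum is known. The sketch below is an illustration under assumptions chosen here, not the paper's setup: the series is $\delta_k = 2^{-k}$ (which sums to 1), the truncation level has tail probabilities $\mathbb{P}(\mathcal{K}\geq k)=\frac{1}{k}$, and it is sampled as $\lceil u/(1-u)\rceil$ with $u$ uniform. The names `deltas`, `weighted_partial_sums` and `max_k` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy series: delta_k = 2^{-k} for k = 1, 2, ..., whose full sum is exactly 1.
max_k = 64  # later terms are below double precision, so capping the table is harmless
ks = np.arange(1, max_k + 1)
deltas = 0.5 ** ks

# Assumed truncation law: P(K >= k) = 1/k, so every kept term is reweighted by k.
weighted_partial_sums = np.cumsum(deltas * ks)

# Draw K with P(K >= k) = 1/k via K = ceil(u / (1 - u)), u ~ Uniform[0, 1).
n = 200_000
u = rng.random(n)
K = np.ceil(u / (1.0 - u)).astype(np.int64)
K = np.clip(K, 1, None)  # guards the measure-zero case u == 0

# Russian Roulette estimate for each draw: sum over k <= K of delta_k / P(K >= k).
estimates = weighted_partial_sums[np.minimum(K, max_k) - 1]

print("Mean of truncated estimates: %.4f (full sum is 1)" % estimates.mean())
print("Empirical P(K >= 4): %.4f (target 0.25)" % (K >= 4).mean())
```

Each individual estimate only ever sums finitely many terms, yet the average matches the infinite sum because the reweighting by $1/\mathbb{P}(\mathcal{K}\geq k)$ compensates for how often term $k$ is dropped.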
To see why this is an unbiased estimator,

$$ \begin{align*} \mathbb{E}[\text{SUMO}] &= \mathbb{E}\left[\text{IWAE}_1 + \sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \right] \nonumber\\ &= \mathbb{E}_{x's}\left[\text{IWAE}_1 + \mathbb{E}_{K}\left[\sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \right]\right] \nonumber \end{align*} $$

The inner expectation can be further expanded:

$$ \begin{align*} \mathbb{E}_{K}\left[\sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \right] &= \sum_{K=1}^\infty p(K)\sum_{k=1}^K \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \\ &= \sum_{k=1}^\infty \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \sum_{K=k}^\infty p(K) \\ &= \sum_{k=1}^\infty \frac{\Delta_k}{\mathbb{P}(\mathcal{K}\geq k)} \mathbb{P}(\mathcal{K}\geq k) \\ &= \sum_{k=1}^\infty\Delta_k \\ &= \text{IWAE}_{2} - \text{IWAE}_1 + \text{IWAE}_{3} - \text{IWAE}_2 + ... = \lim_{k\rightarrow\infty}\text{IWAE}_{k}-\text{IWAE}_1 \end{align*} $$

which shows $\mathbb{E}[\text{SUMO}] = \mathbb{E}[\text{IWAE}_\infty] = \log Z$.

> (N.B.) Some care needs to be taken when taking the limit and exchanging the order of summation. See the paper for a more formal derivation.

A choice of $p(K)$ proposed in the paper satisfies $\mathbb{P}(\mathcal{K}\geq K)=\frac{1}{K}$. We can sample such a $K$ easily using the [inverse CDF](https://en.wikipedia.org/wiki/Inverse_transform_sampling), $K=\lceil\frac{u}{1-u}\rceil$ where $u$ is sampled uniformly from the interval $[0,1]$. The ceiling matters here: $\lceil\frac{u}{1-u}\rceil\geq K$ exactly when $u>\frac{K-1}{K}$, which has probability $\frac{1}{K}$. Now, putting everything together, we can estimate the log-partition function using SUMO.
```python
count = 0  # total number of samples drawn from q so far
bs = 10    # number of SUMO estimates sharing each draw of k
iwlb = list()
while count <= 1000000:
    # sample the truncation level with P(K >= k) = 1/k
    u = np.random.rand(1)
    k = np.ceil(u/(1-u)).astype(int)[0]
    xs = q.rvs(bs*(k+1))
    elbo = - U(xs) - q.logpdf(xs)
    iwlb_ = elbo.reshape(bs, k+1)
    # column j holds IWAE_{j+1}, computed from the first j+1 samples of each row
    iwlb_ = np.log(np.cumsum(np.exp(iwlb_), 1) / np.arange(1,k+2))
    # IWAE_1 plus the differences Delta_j, each weighted by 1/P(K >= j) = j
    iwlb_ = iwlb_[:,0] + ((iwlb_[:,1:k+1] - iwlb_[:,0:k]) * np.arange(1,k+1)).sum(1)
    count += bs * (k+1)
    iwlb.append(iwlb_)
iwlb = np.concatenate(iwlb)
plt.hist(iwlb, range(-5,10))
print("Estimate: %.4f / Ground truth: %.4f" % (iwlb.mean(), np.log(2*np.pi)))
print("Empirical variance: %.4f" % iwlb.var())
```

`Estimate: 1.8359 / Ground truth: 1.8379`

`Empirical variance: 4.1794`

https://i.imgur.com/04kPKo5.png

Indeed, the empirical average is quite close to the true log-partition function of the energy. However, we can also see that the distribution of the estimator is much more spread out. In fact, it is very heavy-tailed.

Note that I did not tune the proposal distribution $q$ based on the ELBO, IWAE or SUMO. In the paper, the authors propose to tune $q$ to minimize the variance of the $\text{SUMO}$ estimator, which might be an interesting trick to look at next.

(Reposted, see more details and code from https://www.chinweihuang.com/pages/sumo)