The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
Chris J. Maddison
and
Andriy Mnih
and
Yee Whye Teh
arXiv e-Print archive, 2016
Keywords:
cs.LG, stat.ML
First published: 2016/11/02

Abstract: The reparameterization trick enables the optimization of large scale
stochastic computation graphs via gradient descent. The essence of the trick is
to refactor each stochastic node into a differentiable function of its
parameters and a random variable with fixed distribution. After refactoring,
the gradients of the loss propagated by the chain rule through the graph are
low variance unbiased estimators of the gradients of the expected loss. While
many continuous random variables have such reparameterizations, discrete random
variables lack continuous reparameterizations due to the discontinuous nature
of discrete states. In this work we introduce concrete random variables --
continuous relaxations of discrete random variables. The concrete distribution
is a new family of distributions with closed form densities and a simple
reparameterization. Whenever a discrete stochastic node of a computation graph
can be refactored into a one-hot bit representation that is treated
continuously, concrete stochastic nodes can be used with automatic
differentiation to produce low-variance biased gradients of objectives
(including objectives that depend on the log-likelihood of latent stochastic
nodes) on the corresponding discrete graph. We demonstrate their effectiveness
on density estimation and structured prediction tasks using neural networks.
This paper presents a way to differentiate through discrete random variables by replacing them with continuous random variables. Say you have a discrete [categorical variable][cat] and you're sampling it with the [Gumbel trick][gumbel] like this (the $G_k$ are i.i.d. $\text{Gumbel}(0, 1)$ variables and $\boldsymbol{\alpha} / \sum_k \alpha_k$ is the vector of categorical probabilities):
$$
z = \text{one_hot} \left( \underset{k}{\text{arg max}} [ G_k + \log \alpha_k ] \right)
$$
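Here's a minimal numpy sketch of that sampling procedure, in case it helps (the function name and setup are mine, not from the paper):

```python
import numpy as np

def sample_gumbel_max(log_alpha, rng):
    # G_k ~ Gumbel(0, 1), via the inverse CDF: G = -log(-log(U)), U ~ Uniform(0, 1)
    gumbel = -np.log(-np.log(rng.uniform(size=log_alpha.shape)))
    # arg max over the perturbed log-probabilities, returned as a one-hot vector
    one_hot = np.zeros_like(log_alpha)
    one_hot[np.argmax(gumbel + log_alpha)] = 1.0
    return one_hot

rng = np.random.default_rng(0)
alpha = np.array([1.0, 6.0, 3.0])          # unnormalized weights
z = sample_gumbel_max(np.log(alpha), rng)  # e.g. array([0., 1., 0.])
```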
This paper replaces the one-hot argmax with a softmax and introduces a temperature parameter $\lambda > 0$ that controls how closely the relaxation approximates a discrete sample. As $\lambda$ tends to zero, samples converge to the one-hot samples of the equation above.
$$
z_k = \text{softmax} \left( \frac{\mathbf{G} + \log \boldsymbol{\alpha}}{\lambda} \right)_k = \frac{\exp \left( (G_k + \log \alpha_k) / \lambda \right)}{\sum_{i=1}^{K} \exp \left( (G_i + \log \alpha_i) / \lambda \right)}
$$
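And a sketch of the relaxed sampler, again in plain numpy (in practice you'd write this in an autodiff framework so gradients flow back through $\log \alpha_k$; the names here are illustrative):

```python
import numpy as np

def sample_concrete(log_alpha, temperature, rng):
    # Same Gumbel perturbation as before, but a tempered softmax instead of
    # one_hot(arg max): the sample lives inside the probability simplex.
    gumbel = -np.log(-np.log(rng.uniform(size=log_alpha.shape)))
    logits = (gumbel + log_alpha) / temperature
    logits -= logits.max()  # subtract max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum()

rng = np.random.default_rng(0)
log_alpha = np.log(np.array([1.0, 6.0, 3.0]))
print(sample_concrete(log_alpha, 0.1, rng))  # low temperature: near one-hot
print(sample_concrete(log_alpha, 5.0, rng))  # high temperature: near uniform
```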
I made [some notes][nb] on how this process works, if you'd like more intuition.
Comparison to [Gumbel-softmax][gs]
--------------------------------------------
These papers propose precisely the same distribution, up to changes in notation ([noted there][gs]), and each paper references the other and discusses the differences. The main one they mention is in how the variational objectives are defined. This paper also compares against [VIMCO][vimco], which is probably a harder baseline to beat (a multi-sample estimator versus single-sample ones).
The results in both papers compare against state-of-the-art score-function estimators, and both report strong results (often the best). There are some implementation details to consider, though, such as temperature scheduling and exactly how the variational objective is defined.
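For example, one common choice is to anneal the temperature from a high value toward a floor over training; a minimal sketch (the constants are illustrative, not values prescribed by either paper):

```python
import math

def temperature(step, tau0=1.0, rate=1e-4, tau_min=0.5):
    # Exponential decay toward a floor; the constants are placeholders,
    # not values prescribed by either paper.
    return max(tau_min, tau0 * math.exp(-rate * step))

# e.g. temperature(0) == 1.0, temperature(100_000) == 0.5
```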
[cat]: https://en.wikipedia.org/wiki/Categorical_distribution
[gumbel]: https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
[gs]: http://www.shortscience.org/paper?bibtexKey=journals/corr/JangGP16
[nb]: https://gist.github.com/gngdb/ef1999ce3a8e0c5cc2ed35f488e19748
[vimco]: https://arxiv.org/abs/1602.06725