First published: 2018/10/01

Abstract: In spite of remarkable progress in deep latent variable generative modeling,
training still remains a challenge due to a combination of optimization and
generalization issues. In practice, a combination of heuristic algorithms (such
as hand-crafted annealing of KL-terms) is often used in order to achieve the
desired results, but such solutions are not robust to changes in model
architecture or dataset. The best settings can often vary dramatically from one
problem to another, which requires doing expensive parameter sweeps for each
new case. Here we develop the idea of training VAEs with additional
constraints as a way to control their behaviour. We first present a detailed
theoretical analysis of constrained VAEs, expanding our understanding of how
these models work. We then introduce and analyze a practical algorithm termed
Generalized ELBO with Constrained Optimization, GECO. The main advantage of
GECO for the machine learning practitioner is a more intuitive, yet principled,
process of tuning the loss. This involves defining a set of constraints,
which typically have an explicit relation to the desired model performance, in
contrast to tweaking abstract hyper-parameters which implicitly affect the
model behavior. Encouraging experimental results on several standard datasets
indicate that GECO is a very robust and effective tool to balance
reconstruction and compression constraints.
The paper provides derivations and intuitions about the learning dynamics of VAEs, building on observations about [$\beta$-VAEs][beta]. Using these, the authors derive an alternative way to constrain VAE training that doesn't require the typical heuristics, such as KL warmup or adding noise to the data.
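For context, here is a minimal sketch of the kind of heuristic GECO is meant to replace: linearly annealing the KL weight ("warmup") over a hand-picked number of steps. The schedule length is my assumption; it's exactly the sort of per-problem hyper-parameter the paper argues against.

```python
# KL-warmup heuristic: anneal the KL weight linearly from 0 to 1.
# `warmup_steps` is an assumed hyper-parameter that usually needs
# re-tuning for every new architecture or dataset.
def kl_weight(step: int, warmup_steps: int = 10_000) -> float:
    return min(1.0, step / warmup_steps)

# Inside the training loop the objective would then be:
#   loss = reconstruction_loss + kl_weight(step) * kl_divergence
```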
How exactly would this change a typical implementation? Typically, SGD is used to [optimize the ELBO directly](https://github.com/pytorch/examples/blob/master/vae/main.py#L91-L95). Using GECO, I keep a moving average of my constraint $C$ (chosen based on what I want the VAE to do, but it can be as simple as the reconstruction error minus a tolerance parameter) and use that to update a Lagrange multiplier, which controls the weight of the constraint in the loss. [This implementation](https://github.com/denproc/Taming-VAEs/blob/master/train.py#L83-L97) from a class project appears to be correct.
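A minimal PyTorch sketch of that loop, assuming the paper's squared-error constraint $C = \|x - g(z)\|^2 - \mathrm{tol}^2 \le 0$ and a multiplicative multiplier update. The names `model`, `opt`, `loader` and the hyper-parameters `alpha`, `nu`, `tol` are my assumptions, not taken from either linked implementation:

```python
import torch

lam, c_ma = torch.tensor(1.0), None   # Lagrange multiplier; constraint moving average
alpha, nu, tol = 0.99, 1e-2, 0.1      # MA decay, multiplier step size, tolerance

for x, _ in loader:                   # `loader`, `model`, `opt` assumed to exist
    recon, mu, logvar = model(x)
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1).mean()
    # Constraint C <= 0: mean squared reconstruction error below tol^2.
    c = (x - recon).pow(2).mean() - tol ** 2

    with torch.no_grad():             # track the smoothed constraint value
        c_ma = c.detach() if c_ma is None else alpha * c_ma + (1 - alpha) * c
    # Takes the value of the smoothed constraint, but the gradient of the
    # current batch's constraint.
    c_hat = c + (c_ma - c).detach()

    loss = kl + lam * c_hat           # generalized ELBO (Lagrangian)
    opt.zero_grad()
    loss.backward()
    opt.step()

    with torch.no_grad():             # multiplicative update keeps lam > 0
        lam = lam * torch.exp(nu * c_ma)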
Note how the multiplier grows while the constraint is violated ($C_{ma} > 0$) and decays once it is satisfied, so the reconstruction/KL trade-off regulates itself. Given how this stabilizes training, I can't help but think of it as batch norm for VAEs.
[beta]: https://openreview.net/forum?id=Sy2fzU9gl