Summary by Gavin Gray
When training a [VAE][] you will have an inference network $q(z|x)$. If you have another source of information you'd like to base the approximate posterior on, like some labels $s$, then you would make $q(z|x,s)$. But $q$ is a complicated function, and it can ignore $s$ if it wants, and still perform well. This paper describes an adversarial way to force $q$ to use $s$.
This is made more complicated in the paper, because $s$ is not necessarily a label, and in fact _is real and continuous_ (because it's easier to backprop in that case). In fact, we're going to _learn_ the representation of $s$, but force it to contain the label information using the training procedure. To be clear, with $x$ as our input (image or whatever):
$$
s = f_{s}(x)
$$
$$
\mu, \sigma = f_{z}(x,s)
$$
We sample $z$ using $\mu$ and $\sigma$ [according to the reparameterization trick, as this is a VAE][vae]:
$$
z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
$$
And then we use our decoder to turn these latent variables into images:
$$
\tilde{x} = \text{Dec}(z,s)
$$
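To make the shapes concrete, here is a minimal sketch of this forward pass. The dimensions and single-layer networks are made up for illustration; `f_s`, `f_z` and `dec` are stand-ins for whatever encoder/decoder architectures the paper actually uses.

```python
import torch
import torch.nn as nn

# Toy dimensions, purely for illustration.
x_dim, s_dim, z_dim = 784, 16, 32

f_s = nn.Linear(x_dim, s_dim)              # s = f_s(x)
f_z = nn.Linear(x_dim + s_dim, 2 * z_dim)  # outputs (mu, log-variance)
dec = nn.Linear(z_dim + s_dim, x_dim)      # x_tilde = Dec(z, s)

def forward(x):
    s = f_s(x)
    mu, logvar = f_z(torch.cat([x, s], dim=1)).chunk(2, dim=1)
    eps = torch.randn_like(mu)
    z = mu + eps * torch.exp(0.5 * logvar)  # reparameterization trick
    x_tilde = dec(torch.cat([z, s], dim=1))
    return s, z, mu, logvar, x_tilde
```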
Training Procedure
--------------------------
We are going to create four parallel loss functions, and incorporate a discriminator to train this:
1. Reconstruction loss plus variational regularizer; propagate an example $x_1$ through the VAE to get $s_1$, $z_1$ (latent) and $\tilde{x}_{1}$.
2. Reconstruction loss with a different $s$:
1. Propagate $x_1'$, a different sample with the __same class__ as $x_1$, through the encoder to get $s_1'$.
2. Pass $z_1$ and $s_1'$ to your decoder.
3. As $s_1'$ _should_ include the label information, you should have reproduced $x_1$, so apply reconstruction loss to whatever your decoder has given you (call it $\tilde{x}_1'$).
3. Adversarial Loss encouraging realistic examples from the same class, regardless of $z$.
1. Propagate $x_2$ (totally separate example) through the network to get $s_2$.
2. Generate two $\tilde{x}_{2}$ samples (both paired with $s_2$): one using $z$ drawn from the prior $p(z)$, and one using $z_{1}$.
3. Get the adversary to classify these as fake versus the real sample $x_{2}$.
This is pretty well described in Figure 1 in the paper.
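Putting the three branches together, here is a rough sketch of one training step, reusing the made-up modules from the snippet above plus a hypothetical discriminator `D` that maps an image to a probability of being real. Loss weights, discriminator inputs and architectures all differ in the actual paper; this is just to show how the pieces connect.

```python
import torch
import torch.nn.functional as F

def training_step(x1, x1_prime, x2, D):
    # 1. Plain VAE loss on x1: reconstruction plus KL regularizer.
    s1, z1, mu, logvar, x1_tilde = forward(x1)
    recon = F.mse_loss(x1_tilde, x1)
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    # 2. Swapped reconstruction: s from a same-class example x1',
    #    paired with z1, should still reconstruct x1.
    s1_prime = f_s(x1_prime)
    x1_tilde_swap = dec(torch.cat([z1, s1_prime], dim=1))
    recon_swap = F.mse_loss(x1_tilde_swap, x1)

    # 3. Adversarial loss: images decoded with s2 should look like
    #    real members of x2's class, whatever z they were given.
    s2 = f_s(x2)
    z_prior = torch.randn_like(z1)
    fake_a = dec(torch.cat([z_prior, s2], dim=1))  # z sampled from p(z)
    fake_b = dec(torch.cat([z1, s2], dim=1))       # z carried over from x1
    fake_scores = torch.cat([D(fake_a), D(fake_b)])
    real_scores = D(x2)
    eps = 1e-8
    adv_gen = -torch.log(fake_scores + eps).mean()            # fool D
    adv_disc = (-torch.log(real_scores + eps).mean()
                - torch.log(1 - fake_scores + eps).mean())    # train D

    gen_loss = recon + kl + recon_swap + adv_gen
    return gen_loss, adv_disc
```

As in a standard GAN, `gen_loss` and `adv_disc` would be minimized alternately with separate optimizers, detaching the fake images when updating the discriminator.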
Experiments show that $s$ ends up coding for the class, and $z$ codes for other stuff, like the angle of digits or line thickness. They also try to classify the label using $z$ and $s$ separately, and show that $s$ is useful but $z$ is not (it predicts only at chance level). So, it works.
[vae]: https://jaan.io/unreasonable-confusion/