When training a [VAE][] you will have an inference network $q(z|x)$. If you have another source of information you'd like to base the approximate posterior on, like some labels $s$, then you would make $q(z|x,s)$. But $q$ is a complicated function, and it can ignore $s$ if it wants, and still perform well. This paper describes an adversarial way to force $q$ to use $s$.

This is made more complicated in the paper, because $s$ is not necessarily a label, and in fact _is real and continuous_ (because it's easier to backprop in that case). In fact, we're going to _learn_ the representation of $s$, but force it to contain the label information using the training procedure. To be clear, with $x$ as our input (image or whatever):

$$ s = f_{s}(x) $$

$$ \mu, \sigma = f_{z}(x,s) $$

We sample $z$ using $\mu$ and $\sigma$ [according to the reparameterization trick, as this is a VAE][vae]:

$$ z \sim \mathcal{N}(\mu, \sigma) $$

And then we use our decoder to turn these latent variables into images:

$$ \tilde{x} = \text{Dec}(z,s) $$

Training Procedure
--------------------------

We are going to create four parallel loss functions, and incorporate a discriminator to train this:

1. Reconstruction loss plus variational regularizer; propagate an example $x_1$ through the VAE to get $s_1$, $z_1$ (latent) and $\tilde{x}_{1}$.
2. Reconstruction loss with a different $s$:
    1. Propagate $x_1'$, a different sample with the __same class__ as $x_1$, to get $s_1'$.
    2. Pass $z_1$ and $s_1'$ to your decoder.
    3. As $s_1'$ _should_ include the label information, you should have reproduced $x_1$, so apply reconstruction loss to whatever your decoder has given you (call it $\tilde{x}_1'$).
3. Adversarial loss encouraging realistic examples from the same class, regardless of $z$:
    1. Propagate $x_2$ (totally separate example) through the network to get $s_2$.
    2. Generate two $\tilde{x}_{2}$ variables, one with the prior by sampling from $p(z)$ and one using $z_{1}$.
    3. Get the adversary to classify these as fake versus the real sample $x_{2}$.

This is pretty well described in Figure 1 in the paper.

Experiments show that $s$ ends up coding for the class, and $z$ codes for other stuff, like the angle of digits or line thickness. They also try to classify using $z$ and $s$ and show that $s$ is useful but $z$ is not (it can only predict as well as chance). So, it works.

[vae]: https://jaan.io/unreasonable-confusion/
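
To make the data flow of the three steps above concrete, here is a minimal NumPy sketch. It is not the paper's implementation: the random linear maps standing in for $f_s$, $f_z$, $\text{Dec}$ and the adversary, the toy dimensions, the squared-error reconstruction loss and the unconditioned logistic adversary are all assumptions made for illustration. The inputs are random vectors, so `x1_prime` does not really share a class with `x1` here; in real training it would.

```python
import numpy as np

rng = np.random.default_rng(0)
X_DIM, S_DIM, Z_DIM = 8, 2, 3  # toy sizes, assumed for illustration

# Random linear maps standing in for the learned networks (hypothetical).
W_s = rng.normal(scale=0.1, size=(S_DIM, X_DIM))            # f_s
W_mu = rng.normal(scale=0.1, size=(Z_DIM, X_DIM + S_DIM))   # f_z, mean head
W_ls = rng.normal(scale=0.1, size=(Z_DIM, X_DIM + S_DIM))   # f_z, log-sigma head
W_dec = rng.normal(scale=0.1, size=(X_DIM, Z_DIM + S_DIM))  # Dec
w_adv = rng.normal(scale=0.1, size=X_DIM)                   # adversary

def f_s(x):
    return W_s @ x

def f_z(x, s):
    h = np.concatenate([x, s])
    return W_mu @ h, np.exp(W_ls @ h)  # mu, sigma (exp keeps sigma > 0)

def sample_z(mu, sigma):
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
    return mu + sigma * rng.standard_normal(mu.shape)

def dec(z, s):
    return W_dec @ np.concatenate([z, s])

def adversary(x):
    # Probability that x is a real sample (logistic regression stand-in).
    return 1.0 / (1.0 + np.exp(-(w_adv @ x)))

def mse(a, b):
    return float(np.mean((a - b) ** 2))

def kl_to_prior(mu, sigma):
    # KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions.
    return float(0.5 * np.sum(mu**2 + sigma**2 - 2.0 * np.log(sigma) - 1.0))

# Stand-in data; in practice x1 and x1_prime share a class, x2 is unrelated.
x1, x1_prime, x2 = (rng.normal(size=X_DIM) for _ in range(3))

# 1. Reconstruction loss plus variational regularizer.
s1 = f_s(x1)
mu1, sigma1 = f_z(x1, s1)
z1 = sample_z(mu1, sigma1)
loss_vae = mse(dec(z1, s1), x1) + kl_to_prior(mu1, sigma1)

# 2. Swap s: decode z1 with s1' and still ask for x1 back.
s1_prime = f_s(x1_prime)
loss_swap = mse(dec(z1, s1_prime), x1)

# 3. Adversarial: two fakes built from s2, judged against the real x2.
s2 = f_s(x2)
fake_from_z1 = dec(z1, s2)
fake_from_prior = dec(rng.standard_normal(Z_DIM), s2)
eps = 1e-12  # numerical guard inside the logs
loss_adv = float(-(np.log(adversary(x2) + eps)
                   + np.log(1.0 - adversary(fake_from_z1) + eps)
                   + np.log(1.0 - adversary(fake_from_prior) + eps)))

print(loss_vae, loss_swap, loss_adv)
```

In this sketch `loss_adv` is the quantity the adversary would minimize, while the encoder and decoder would be updated to make the two fakes harder to tell apart from `x2`. How the terms are weighted against each other is left out, since the summary does not specify it.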