[link]
TLDR; The authors adapt Generative Adversarial Networks (GANs) to RNNs and train a discriminator to distinguish between sequences generated using teacher forcing (feeding ground-truth inputs to the RNN) and free-running sampling (feeding the model's own outputs back as the next inputs). The inputs to the discriminator are both the predictions and the hidden states of the generative RNN. The generator is trained to fool the discriminator, forcing the dynamics of teacher forcing and free-running sampling to become more similar. This procedure acts as a regularizer and results in better sample quality and generalization, particularly for long sequences. The authors evaluate their framework on Language Modeling (PTB), Pixel Generation (Sequential MNIST), Handwriting Generation, and Music Synthesis.

### Key Points

- Problem: During inference, errors in an RNN easily compound because the conditioning context may diverge from what is seen during training, when the ground-truth labels are fed as inputs (teacher forcing).
- Goal of Professor Forcing: Make the generative (free-running) behavior and the teacher-forced behavior match as closely as possible.
- Discriminator details
  - Input is a behavior sequence `B(x, y, theta)` from the generative RNN that contains the hidden states and outputs.
  - The training objective is to correctly classify whether a behavior sequence was generated using teacher forcing or free-running sampling.
- Generator
  - Standard RNN with the MLE training objective and an additional term to fool the discriminator: change the free-running behavior to match the teacher-forced behavior while keeping the latter constant.
  - Optional second term: change the teacher-forced behavior to match the free-running behavior.
  - As in a GAN, gradients are backpropagated from the discriminator into the generator (see the training-step sketch below).
- Architectures
  - Generator is a standard GRU Recurrent Neural Network with a softmax output layer.
  - Behavior function `B(x, y, theta)` outputs the pre-tanh activations of the GRU states and the softmax output.
  - Discriminator: Bidirectional GRU with a 3-layer MLP on top.
- Training trick: To prevent "bad gradients", the authors backprop from the discriminator into the generator only if the discriminator's classification accuracy is between 75% and 99%.
- Trained using the Adam optimizer.
- Experiments
  - PTB character-level modeling: Reduction in test NLL (1.48 BPC); Professor Forcing seems to act as a regularizer.
  - Sequential MNIST: Second-best NLL (79.58) after PixelCNN.
  - Handwriting generation: Professor Forcing is better at generating sequences longer than those seen during training, as judged by human evaluation.
  - Music synthesis: Human evaluation significantly favors Professor Forcing.
  - Negative result on word-level modeling: Professor Forcing doesn't have any effect, perhaps because long-term dependencies are more pronounced in character-level modeling.
- The authors show using t-SNE that the hidden state distributions actually become more similar when using Professor Forcing.

### Thoughts

- Props to the authors for a very clear and well-written paper. This is rarer than it should be :)
- It's an interesting idea to also match the states of the RNN instead of just the outputs. Intuitively, matching the outputs should implicitly match the state distribution. I wonder if the authors tried this and it didn't work as expected.
- Note from [Ethan Caballero](https://github.com/ethancaballero) about why they chose to match hidden states: It's significantly harder to use GANs on sampled (argmax) output tokens because they are discrete (as opposed to continuous like the hidden states and their respective softmaxes). They would have had to estimate gradients through the discrete outputs with policy gradients as in [seqGAN](https://github.com/dennybritz/deeplearning-papernotes/blob/master/notes/seq-gan.md), which is [harder to get to converge](https://www.quora.com/Do-you-have-any-ideas-on-how-to-get-GANs-to-work-with-text), so they probably just stuck with the hidden states, which already contain information about the discrete sampled outputs (the index of the highest probability in the distribution) anyway. The Professor Forcing method is unique in that one has access to the continuous probability distribution of each token at each timestep of the two sequence-generation modes being pushed closer together. Conversely, when applying GANs to push real samples and generated samples closer together, as is traditionally done in models like seqGAN, one only has access to the next discrete token (not the continuous probability distribution over the next token) at each timestep, which prevents straightforward differentiation (as used in Professor Forcing) from being applied and forces one to use policy gradient estimation. However, there's a chance one might be able to use straightforward differentiation to train seqGANs in the traditional sampling case if one swaps out each discrete sampled token for its continuous distributional word embedding (from pretrained word2vec, GloVe, etc.), but no one has tried it yet TTBOMK.
- I would've liked to see a comparison of the two regularization terms in the generator. The experiments don't make it clear whether both or only one of them is used.
- I'm guessing that this architecture is quite challenging to train. I would've liked to see a bit more detail about when/how they trade off the training of the discriminator and the generator.
- Translation is another obvious task to apply this to. I'm interested in whether or not this works for seq2seq.
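To make the mechanics above concrete, here is a minimal PyTorch sketch of what one Professor Forcing training step could look like. This is not the authors' code: the class names, layer sizes (`VOCAB`, `HID`, etc.), single-layer `GRUCell` generator, and the random toy batch are my assumptions; only the overall shape (behavior = hidden states plus softmax outputs, bidirectional-GRU + MLP discriminator, NLL + fooling term, 75%-99% accuracy gate, Adam) follows the notes above.

```python
# Hedged sketch of a Professor Forcing training step (assumed setup, not the paper's code).
import torch
import torch.nn as nn
import torch.nn.functional as F

VOCAB, HID, T, BATCH = 50, 128, 32, 16  # illustrative sizes

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, HID)
        self.cell = nn.GRUCell(HID, HID)
        self.out = nn.Linear(HID, VOCAB)

    def roll(self, x, teacher_forced=True):
        """Run the RNN over x; return per-step log-probs and the behavior
        sequence (hidden states concatenated with softmax outputs)."""
        h = x.new_zeros(x.size(0), HID, dtype=torch.float)
        inp = x[:, 0]
        log_probs, behavior = [], []
        for t in range(1, x.size(1)):
            h = self.cell(self.embed(inp), h)
            logits = self.out(h)
            probs = F.softmax(logits, dim=-1)
            log_probs.append(F.log_softmax(logits, dim=-1))
            behavior.append(torch.cat([h, probs], dim=-1))
            if teacher_forced:
                inp = x[:, t]                                   # ground-truth input
            else:
                # sampling is non-differentiable; adversarial gradients flow
                # through the hidden states / softmaxes in `behavior` instead
                inp = torch.multinomial(probs, 1).squeeze(1)
        return torch.stack(log_probs, 1), torch.stack(behavior, 1)

class Discriminator(nn.Module):
    """Bidirectional GRU over the behavior sequence with an MLP on top."""
    def __init__(self):
        super().__init__()
        self.rnn = nn.GRU(HID + VOCAB, HID, batch_first=True, bidirectional=True)
        self.mlp = nn.Sequential(nn.Linear(2 * HID, HID), nn.ReLU(),
                                 nn.Linear(HID, HID), nn.ReLU(),
                                 nn.Linear(HID, 1))

    def forward(self, b):
        feats, _ = self.rnn(b)
        return torch.sigmoid(self.mlp(feats.mean(dim=1))).squeeze(-1)

gen, disc = Generator(), Discriminator()
opt_g = torch.optim.Adam(gen.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-3)

x = torch.randint(0, VOCAB, (BATCH, T))  # toy batch of token ids

# Discriminator update: teacher-forced behavior is the "real" class,
# free-running behavior the "fake" class.
log_p_tf, b_tf = gen.roll(x, teacher_forced=True)
_, b_fr = gen.roll(x, teacher_forced=False)
d_tf, d_fr = disc(b_tf.detach()), disc(b_fr.detach())
loss_d = F.binary_cross_entropy(d_tf, torch.ones_like(d_tf)) + \
         F.binary_cross_entropy(d_fr, torch.zeros_like(d_fr))
opt_d.zero_grad(); loss_d.backward(); opt_d.step()

# Generator update: NLL plus a term that makes the free-running behavior
# look teacher-forced to the discriminator (only the generator is stepped here).
nll = F.nll_loss(log_p_tf.reshape(-1, VOCAB), x[:, 1:].reshape(-1))
_, b_fr = gen.roll(x, teacher_forced=False)
fool = F.binary_cross_entropy(disc(b_fr), torch.ones(BATCH))

# Training trick from the notes: pass adversarial gradients into the generator
# only while discriminator accuracy is between 75% and 99%.
acc = 0.5 * ((d_tf > 0.5).float().mean() + (d_fr < 0.5).float().mean())
loss_g = nll + (fool if 0.75 < acc.item() < 0.99 else 0.0)
opt_g.zero_grad(); loss_g.backward(); opt_g.step()
```

The optional second generator term from the paper (pushing the teacher-forced behavior toward the free-running one) would be an analogous BCE term on `disc(b_tf)` with target 0; it is omitted here to keep the sketch short.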
[link]
The authors present a method, related to teacher forcing, that uses generative adversarial networks to guide training on sequential tasks. This work describes a novel algorithm to ensure that the dynamics of an LSTM during inference follow those seen during training. The motivating example is sampling for a long number of steps at test time while only training on shorter sequences. Experimental results are shown on PTB language modelling, MNIST, handwriting generation, and music synthesis.

The paper is similar to Generative Adversarial Networks (GAN): in addition to a normal sequence-model loss function, the parameters try to "fool" a classifier that tries to distinguish sequences generated by the model from real data. A few objectives are proposed in Section 2.2 (a rough sketch of them is given below). The key difference from a GAN is the B in Equations 1-4: B is a function that outputs some statistics of the model, such as the hidden state of the RNN, whereas a GAN instead discriminates the actual output sequences.

This paper proposes a method for training recurrent neural networks (RNNs) in the framework of adversarial training. Since RNNs can be used to generate sequential data, the goal is to optimize the network parameters in such a way that the generated samples are hard to distinguish from real data. This is particularly interesting for RNNs because the classical training criterion only involves the prediction of the next symbol in the sequence: given a sequence of symbols $x_1, ..., x_t$, the model is trained to output $y_t$ as close to $x_{t+1}$ as possible. Training that way does not produce models that are robust during generation, as a mistake at time $t$ can make the prediction at time $t+k$ totally unreliable. This idea is somewhat similar to computing a sequence-level loss in the context of encoder-decoder translation models: the loss can only be computed after a complete sequence has been generated.
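For reference, here is a hedged reconstruction of the shape of those objectives, pieced together from the two notes above rather than copied from the paper; $B(x, y, \theta_g)$ is the behavior sequence of the generator with parameters $\theta_g$, and $D(\cdot, \theta_d)$ is the discriminator's estimated probability that a behavior sequence came from teacher forcing:

$$
\begin{aligned}
C_d(\theta_d) &= \mathbb{E}_{(x,y)\sim \text{data}}\big[-\log D(B(x,y,\theta_g),\theta_d)\big]
 + \mathbb{E}_{x\sim \text{data},\ \hat{y}\sim P_{\theta_g}(\cdot\mid x)}\big[-\log\big(1 - D(B(x,\hat{y},\theta_g),\theta_d)\big)\big] \\
\text{NLL}(\theta_g) &= \mathbb{E}_{(x,y)\sim \text{data}}\big[-\log P_{\theta_g}(y\mid x)\big] \\
C_f(\theta_g) &= \mathbb{E}_{x\sim \text{data},\ \hat{y}\sim P_{\theta_g}(\cdot\mid x)}\big[-\log D(B(x,\hat{y},\theta_g),\theta_d)\big] \\
C_t(\theta_g) &= \mathbb{E}_{(x,y)\sim \text{data}}\big[-\log\big(1 - D(B(x,y,\theta_g),\theta_d)\big)\big]
\end{aligned}
$$

The discriminator minimizes $C_d$, while the generator minimizes $\text{NLL} + C_f$, with $C_t$ as the optional second term mentioned in the first note; the exact weighting and notation in the paper's Equations 1-4 may differ from this sketch.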