Professor Forcing: A New Algorithm for Training Recurrent Networks
Alex Lamb, Anirudh Goyal, Ying Zhang, Saizheng Zhang, Aaron Courville, and Yoshua Bengio
arXiv e-Print archive - 2016 via Local arXiv
Keywords: stat.ML, cs.LG
First published: 2016/10/27
Abstract: The Teacher Forcing algorithm trains recurrent networks by supplying observed
sequence values as inputs during training and using the network's own
one-step-ahead predictions to do multi-step sampling. We introduce the
Professor Forcing algorithm, which uses adversarial domain adaptation to
encourage the dynamics of the recurrent network to be the same when training
the network and when sampling from the network over multiple time steps. We
apply Professor Forcing to language modeling, vocal synthesis on raw waveforms,
handwriting generation, and image generation. Empirically we find that
Professor Forcing acts as a regularizer, improving test likelihood on character
level Penn Treebank and sequential MNIST. We also find that the model
qualitatively improves samples, especially when sampling for a large number of
time steps. This is supported by human evaluation of sample quality. Trade-offs
between Professor Forcing and Scheduled Sampling are discussed. We produce
T-SNEs showing that Professor Forcing successfully makes the dynamics of the
network during training and sampling more similar.
The authors present a method similar to teacher forcing that uses generative adversarial networks to guide training on sequential tasks.
This work describes a novel algorithm to ensure that the dynamics of an LSTM during inference follow those during training. The motivating example is sampling for a large number of steps at test time after training only on shorter sequences. Experimental results are shown on PTB language modelling, MNIST, handwriting generation and music synthesis.
The paper is similar to Generative Adversarial Networks (GANs): in addition to minimizing a normal sequence model loss, the parameters are trained to “fool” a classifier. That classifier tries to distinguish sequences generated by the sequence model from real data. A few objectives are proposed in Section 2.2. The key difference from a GAN is the $B$ in equations 1-4: $B$ is a function that outputs some statistics of the model, such as the hidden state of the RNN, whereas a GAN discriminates the actual output sequences.
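The combination of a sequence loss with an adversarial term on $B$ can be sketched in a few lines. This is a minimal illustration, not the paper's implementation. Following the abstract, it assumes the classifier compares behaviour from the teacher-forced (training) mode with behaviour from the free-running (sampling) mode. The functions `behaviour` and `discriminator`, their parameters `w` and `c`, the `weight` argument, and all numbers are hypothetical stand-ins chosen for the example.

```python
import math


def nll(probs_of_targets):
    """Average negative log-likelihood of the observed next symbols."""
    return -sum(math.log(p) for p in probs_of_targets) / len(probs_of_targets)


def behaviour(hidden_states):
    """Hypothetical B: summarise one run by the mean of its hidden states."""
    return sum(hidden_states) / len(hidden_states)


def discriminator(b, w=2.0, c=0.0):
    """Hypothetical D: probability that behaviour b came from teacher-forced mode."""
    return 1.0 / (1.0 + math.exp(-(w * b + c)))


def discriminator_loss(d_teacher, d_free):
    """Cross-entropy with teacher-forced behaviour labelled 1, free-running 0."""
    return -math.log(d_teacher) - math.log(1.0 - d_free)


def generator_loss(nll_value, d_free, weight=1.0):
    """Sequence loss plus a term that shrinks as free-running behaviour fools D."""
    return nll_value - weight * math.log(d_free)


# Made-up scalar hidden states from the two modes of the same network.
h_teacher = [0.4, 0.5, 0.6]
h_free = [-0.2, -0.1, -0.3]

d_teacher = discriminator(behaviour(h_teacher))
d_free = discriminator(behaviour(h_free))

# Made-up probabilities the model assigned to the observed next symbols.
sequence_loss = nll([0.7, 0.6, 0.8])

print("discriminator loss:", discriminator_loss(d_teacher, d_free))
print("generator loss:    ", generator_loss(sequence_loss, d_free))
```

The discriminator sees only the output of `behaviour`, never the generated symbols themselves, which is the distinction from a standard GAN drawn above.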
This paper proposes a method for training recurrent neural networks (RNNs) in the framework of adversarial training. Since RNNs can be used to generate sequential data, the goal is to optimize the network parameters in such a way that the generated samples are hard to distinguish from real data. This is particularly interesting for RNNs because the classical training criterion only involves the prediction of the next symbol in the sequence. Given a sequence of symbols $x_1, \dots, x_t$, the model is trained to output $y_t$ as close to $x_{t+1}$ as possible. Training this way does not yield models that are robust during generation, because a mistake at time $t$ can make the prediction at time $t+k$ totally unreliable. This idea is somewhat similar to computing a sentence-wide loss in the context of encoder-decoder translation models. The loss can only be computed after a complete sequence has been generated.
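The robustness problem described above can be seen in a toy setting. The sketch below is an illustration under stated assumptions, not an experiment from the paper: the "model" is a hypothetical one-step predictor with a 5% error per step, and the data is a made-up constant sequence. It compares one-step predictions made from observed inputs with multi-step predictions where the model's own outputs are fed back in.

```python
def teacher_forced(step, observed):
    """One-step predictions, each made from the observed previous value."""
    return [step(x) for x in observed[:-1]]


def free_running(step, first, n_steps):
    """Multi-step predictions, each made from the model's own previous output."""
    preds, x = [], first
    for _ in range(n_steps):
        x = step(x)
        preds.append(x)
    return preds


def abs_errors(preds, targets):
    """Absolute error of each prediction against its target."""
    return [abs(p - t) for p, t in zip(preds, targets)]


# Hypothetical data: a constant sequence, i.e. true dynamics x[t+1] = x[t].
observed = [1.0] * 51
targets = observed[1:]


# Hypothetical learned model that overshoots by 5% at every step.
def model_step(x):
    return 1.05 * x


tf_err = abs_errors(teacher_forced(model_step, observed), targets)
fr_err = abs_errors(free_running(model_step, observed[0], len(targets)), targets)

print("teacher-forced error at last step:", tf_err[-1])
print("free-running error at last step:  ", fr_err[-1])
```

With observed inputs the error stays at the per-step level, while in free-running mode the same small mistake compounds at every step, which is why good one-step prediction does not guarantee good long samples.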