SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient
Lantao Yu, Weinan Zhang, Jun Wang, and Yong Yu
arXiv e-Print archive, 2016
Keywords:
cs.LG, cs.AI
First published: 2016/09/18
Abstract: As a new way of training generative models, Generative Adversarial Nets (GAN)
that uses a discriminative model to guide the training of the generative model
has enjoyed considerable success in generating real-valued data. However, it
has limitations when the goal is for generating sequences of discrete tokens. A
major reason lies in that the discrete outputs from the generative model make
it difficult to pass the gradient update from the discriminative model to the
generative model. Also, the discriminative model can only assess a complete
sequence, while for a partially generated sequence, it is non-trivial to
balance its current score and the future one once the entire sequence has been
generated. In this paper, we propose a sequence generation framework, called
SeqGAN, to solve the problems. Modeling the data generator as a stochastic
policy in reinforcement learning (RL), SeqGAN bypasses the generator
differentiation problem by directly performing gradient policy update. The RL
reward signal comes from the GAN discriminator judged on a complete sequence,
and is passed back to the intermediate state-action steps using Monte Carlo
search. Extensive experiments on synthetic data and real-world tasks
demonstrate significant improvements over strong baselines.
Everyone has been thinking about how to apply GANs to discrete sequence data for the past year or so. This paper presents the model that I would guess most people thought of as the first-thing-to-try:
1. Build a recurrent generator model which samples from its softmax outputs at each timestep.
2. Pass sampled sequences to a recurrent discriminator model which distinguishes between sampled sequences and real-data sequences.
3. Train the discriminator under the standard GAN loss.
4. Train the generator with a REINFORCE (policy gradient) objective, where each trajectory is assigned a single episodic reward: the score assigned to the generated sequence by the discriminator.
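Concretely, step 4 boils down to something like the sketch below. This is just a minimal illustration (PyTorch-style, not the authors' code); the call signatures and the `init_state` / `start_tokens` helpers are assumptions made for the example.

```python
import torch

def reinforce_step(generator, discriminator, optimizer, batch_size, seq_len):
    """One REINFORCE update: sample sequences, score them with D, and weight
    each trajectory's log-likelihood by its single episodic reward."""
    log_probs, tokens = [], []
    state = generator.init_state(batch_size)      # assumed helper
    inp = generator.start_tokens(batch_size)      # assumed helper
    for t in range(seq_len):
        logits, state = generator(inp, state)     # logits: (batch, vocab)
        dist = torch.distributions.Categorical(logits=logits)
        inp = dist.sample()                       # sample from the softmax (step 1)
        tokens.append(inp)
        log_probs.append(dist.log_prob(inp))
    seqs = torch.stack(tokens, dim=1)             # (batch, seq_len)
    with torch.no_grad():
        reward = discriminator(seqs)              # D's score for the full sequence
    # Every timestep in a trajectory shares the same episodic reward.
    loss = -(torch.stack(log_probs, dim=1) * reward.unsqueeze(1)).sum(dim=1).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```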
Sounds hacky, right? We're learning a generator with a high-variance, model-free reinforcement learning algorithm, in a severely non-stationary environment. (Here the "environment" is a discriminator that is being learned jointly with the generator.)
There's just one trick in this paper on top of that setup: for non-terminal states, the reward is defined as the *expectation* of the discriminator score after stochastically generating from that state forward. To restate using standard (somewhat sloppy) RL syntax, in different terms than the paper: (under stochastic sequential policy $\pi$, with current state $s_t$, trajectory $\tau_{1:T}$ and discriminator $D(\tau)$)
$$r_t = \mathbb E_{\tau_{t+1:T} \sim \pi(s_t)} \left[ D(\tau_{1:T}) \right]$$
The rewards are estimated via Monte Carlo — i.e., just take the mean of $N$ rollouts from each intermediate state. They claim this helps to reduce variance. That makes intuitive sense, but I don't see any results in the paper demonstrating the effect of varying $N$.
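Here is a sketch of that Monte Carlo estimate, under the same illustrative assumptions as above; `generator.rollout` is a hypothetical helper that completes a prefix by sampling the remaining tokens from the current policy.

```python
import torch

def intermediate_reward(generator, discriminator, prefix, seq_len, n_rollouts=16):
    """Estimate r_t = E[D(x_{1:T})] for a partial sequence by averaging the
    discriminator's score over N completions sampled from the current policy."""
    with torch.no_grad():
        scores = [discriminator(generator.rollout(prefix, seq_len))  # D(x_{1:T})
                  for _ in range(n_rollouts)]
    return torch.stack(scores).mean(dim=0)        # Monte Carlo mean over N rollouts
```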
---
Yep, so it turns out that this sort of works... with a big caveat:
## The big caveat
Graph from appendix:
![](https://www.dropbox.com/s/5fqh6my63sgv5y4/Bildschirmfoto%202016-09-27%20um%2021.34.44.png?raw=1)
SeqGANs don't work without supervised pretraining. Makes sense — with a cold start, the generator just samples a bunch of nonsense and the discriminator overfits. Both the generator and discriminator are pretrained on supervised data in this paper (see Algorithm 1).
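For reference, that pretraining step is ordinary maximum-likelihood training with teacher forcing. A minimal sketch, assuming the generator can also be run over a whole input sequence at once (the `forward_sequence` helper is an assumption for illustration):

```python
import torch
import torch.nn.functional as F

def mle_pretrain_step(generator, optimizer, real_seqs):
    """One MLE step on real data: predict each token from its prefix."""
    inputs, targets = real_seqs[:, :-1], real_seqs[:, 1:]
    logits = generator.forward_sequence(inputs)    # assumed: (batch, len-1, vocab)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```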
I think it must be possible to overcome this with the proper training tricks and enough sweat. But it's probably more worth our time to address the fundamental problem here of developing better RL for structured prediction tasks.
That link for the image doesn't work for me (permissions on Dropbox?). You can embed images just by writing a URL that ends in .png or .jpg, or you can wrap the URL in the ![](url) markdown syntax to render them. Like this: ![](https://i.imgur.com/hDvHRwT.png)
Fixed! Thanks.
Hi! I can't figure out what the oracle model is. Could someone explain?
Re: your point 2: the discriminator is a CNN, not an RNN. Do you think a CNN could also be used as a generator?