Everyone has been thinking about how to apply GANs to discrete sequence data for the past year or so. This paper presents the model that I would guess most people thought of as the first-thing-to-try:

1. Build a recurrent generator model which samples from its softmax outputs at each timestep.
2. Pass sampled sequences to a recurrent discriminator model which distinguishes between sampled sequences and real-data sequences.
3. Train the discriminator under the standard GAN loss.
4. Train the generator with a REINFORCE (policy gradient) objective, where each trajectory is assigned a single episodic reward: the score assigned to the generated sequence by the discriminator.

Sounds hacky, right? We're learning a generator with a high-variance model-free reinforcement learning algorithm, in a very seriously non-stationary environment. (Here the "environment" is a discriminator being jointly learned with the generator.)

There's just one trick in this paper on top of that setup: for non-terminal states, the reward is defined as the *expectation* of the discriminator score after stochastically generating from that state forward. To restate using standard (somewhat sloppy) RL syntax, in different terms than the paper (under a stochastic sequential policy $\pi$, with current state $s_t$, trajectory $\tau_{1:T}$, and discriminator $D(\tau)$):

$$r_t = \mathbb{E}_{\tau_{t+1:T} \sim \pi(s_t)} \left[ D(\tau_{1:T}) \right]$$

The rewards are estimated via Monte Carlo — i.e., just take the mean of $N$ rollouts from each intermediate state. They claim this helps to reduce variance. That makes intuitive sense, but I don't see any results in the paper demonstrating the effect of varying $N$.

---

Yep, so it turns out that this sort of works... with a big caveat:

## The big caveat

Graph from appendix *(figure not reproduced here)*: SeqGANs don't work without supervised pretraining. Makes sense — with a cold start, the generator just samples a bunch of nonsense and the discriminator overfits. Both the generator and discriminator are pretrained on supervised data in this paper (see Algorithm 1).

I think it must be possible to overcome this with the proper training tricks and enough sweat. But it's probably more worth our time to address the fundamental problem here of developing better RL for structured prediction tasks.
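To make that reward definition concrete, here's a minimal PyTorch sketch (my own illustration, not the authors' code) of the Monte Carlo estimate: average the discriminator's score over $N$ stochastic completions of the partial sequence. The `generator.rollout` and `discriminator` interfaces are hypothetical stand-ins for whatever sampling and scoring functions you actually have.

```python
import torch

def estimate_reward(generator, discriminator, prefix, seq_len, n_rollouts=16):
    """Monte Carlo estimate of r_t = E[D(tau_{1:T})] for a partial sequence.

    Assumes `prefix` is a (batch, t) tensor of already-generated token ids,
    `generator.rollout(prefix, seq_len)` stochastically completes it to length
    `seq_len` by sampling from the generator's softmax at each remaining step,
    and `discriminator(seq)` returns P(real) per sequence.
    """
    with torch.no_grad():
        scores = []
        for _ in range(n_rollouts):
            completed = generator.rollout(prefix, seq_len)  # (batch, seq_len)
            scores.append(discriminator(completed))         # (batch,)
        # The mean over rollouts approximates the expectation over tau_{t+1:T}.
        return torch.stack(scores, dim=0).mean(dim=0)       # (batch,)
```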
---
TLDR; The authors train a Generative Adversarial Network where the generator is an RNN producing discrete tokens. The discriminator is used to provide a reward for each generated sequence (episode) and to train the generator network via Policy Gradients. The discriminator network is a CNN in the experiments. The authors evaluate their model on a synthetic language modeling task and 3 real-world tasks: Chinese poem generation, speech generation and music generation. SeqGAN outperforms competing approaches (MLE, Scheduled Sampling, PG-BLEU) on the synthetic task and outperforms MLE on the real-world tasks based on a BLEU evaluation metric.

#### Key Points

- Code: https://github.com/LantaoYu/SeqGAN
- RL problem setup: the state is the already-generated partial sequence. The action space is the space of possible tokens to output at the current step. Each episode is a fully generated sequence of fixed length T.
- Exposure bias in the Maximum Likelihood approach: during decoding the model generates the next token based on a series of previously generated tokens that it may never have seen during training, leading to compounding errors.
- A discriminator can provide a reward when no task-specific reward (e.g. BLEU score) is available, or when it is expensive to obtain such a reward (e.g. human evaluation).
- The reward is provided by the discriminator at the end of each episode, i.e. when the full sequence is generated. To provide feedback at intermediate steps, the rest of the sequence is sampled via Monte Carlo search.
- Generator and discriminator are trained alternately, and the schedule is defined by the hyperparameters g-steps (number of steps to train the generator), d-steps (number of steps to train the discriminator with newly generated data) and k (number of epochs to train the discriminator on the same set of generated data). A rough sketch of this loop is included after these notes.
- Synthetic task: a randomly initialized LSTM serves as the oracle for a language modeling task, with 10,000 sequences of length 20.
- The hyperparameters g-steps, d-steps and k have a huge impact on training stability and final model performance. Bad settings lead to a model that is barely better than the MLE baseline.

#### My notes:

- Great paper overall. I also really like the synthetic task idea; I think it's a neat way to compare models.
- For the real-world tasks I would've liked to see a comparison to PG-BLEU as they do in the synthetic task. The authors evaluate on BLEU score, so I wonder how much difference a direct optimization of the evaluation metric makes.
- It seems like SeqGAN outperforms MLE significantly only on the poem generation task, not the other tasks. What about the other baselines on the other tasks? What is it about the poem generation that makes SeqGAN perform so well?
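As a rough illustration of how g-steps, d-steps and k interact, here is a schematic of the alternating loop (a paraphrase of the structure of Algorithm 1, not the authors' code; `train_generator_pg`, `generate_samples` and `train_discriminator` are hypothetical helpers standing in for the corresponding steps):

```python
def adversarial_training(generator, discriminator, real_data,
                         iterations, g_steps, d_steps, k):
    # Both networks are assumed to be pretrained with MLE before this loop.
    for _ in range(iterations):
        # g-steps: update the generator with policy gradients, using rewards
        # from Monte Carlo rollouts scored by the current discriminator.
        for _ in range(g_steps):
            train_generator_pg(generator, discriminator)

        # d-steps: sample fresh fake data, then train the discriminator on it
        # for k epochs before generating new samples again.
        for _ in range(d_steps):
            fake_data = generate_samples(generator)
            for _ in range(k):
                train_discriminator(discriminator, real_data, fake_data)
```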
---
GANs for images have made impressive progress in recent years, reaching ever-higher levels of subjective realism. It's also interesting to think about domains where the GAN architecture is less of a good fit. An example of one such domain is natural language. As opposed to images, which are made of continuous pixel values, sentences are fundamentally sequences of discrete values: that is, words. In a GAN, when the discriminator makes its assessment of the realness of the image, the gradient for that assessment can be backpropagated through to the pixel level. The discriminator can say "move that pixel just a bit, and this other pixel just a bit, and then I'll find the image more realistic". However, there is no smoothly flowing continuous space of words, and, even if you use continuous embeddings of words, it's still the case that if you tried to apply a small change to an embedding vector, you almost certainly wouldn't end up with another word; you'd just be somewhere in the middle of nowhere in word space. In short: the discrete nature of language sequences doesn't allow gradients to propagate backwards through to the generator.

The authors of this paper propose a solution: instead of trying to treat their GAN as one big differentiable system, they framed the problem of "generate a sequence that will seem realistic to the discriminator" as a reinforcement learning problem. After all, this property - of your reward just being generated *somewhere* in the environment, not something analytic, not something you can backprop through - is one of the key constraints of reinforcement learning. Here, the more real the discriminator finds your sequence, the higher the reward.

One approach to RL, and the one this paper uses, is that of a policy network, where your parametrized network produces a distribution over actions. You can't update your model to deterministically increase reward, but you can shift probability around in your policy such that the expected reward of following that policy is higher. This key kernel of an idea - GANs for language, but using a policy network framework to get around not having a backprop-able loss/reward - gets you most of the way to understanding what these authors did, but it's still useful to mechanically walk through the specifics.

https://i.imgur.com/CIFuGCG.png

- At each step, the "state" is the existing words in the sequence, and the agent's "action" is the choosing of its next word.
- The discriminator can only be applied to completed sequences, since it's difficult to determine whether an incoherent half-sentence is realistic language. So, when the agent is trying to calculate the reward of an action at a state, it uses Monte Carlo search: randomly "rolling out" many possible futures by sampling from the policy, and then taking the average discriminator judgment over all those futures resulting from each action as its expected reward.
- The generator is an LSTM that produces a softmax over words, which can be interpreted as a policy if it's sampled from randomly (see the sketch below).
- One of the nice benefits of this approach is that it can work well for cases where we don't have a hand-crafted quality-assessment metric, the way we have BLEU score for translation.
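To make the "softmax as policy" idea concrete, here is a small PyTorch sketch (my own illustration, not the paper's implementation): an LSTM generator whose per-step softmax is sampled from as a policy, plus a REINFORCE-style update that weights the sampled tokens' log-probabilities by rewards. The `rewards_fn` argument is a hypothetical stand-in for the Monte Carlo-averaged discriminator scores described above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMPolicy(nn.Module):
    """A toy LSTM generator whose per-step softmax is treated as a policy."""

    def __init__(self, vocab_size, emb_dim=32, hidden_dim=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def step(self, token, state):
        emb = self.embed(token).unsqueeze(1)        # (batch, 1, emb_dim)
        output, state = self.lstm(emb, state)
        logits = self.out(output.squeeze(1))        # (batch, vocab)
        return F.log_softmax(logits, dim=-1), state

def sample_and_update(policy, rewards_fn, optimizer, batch_size, seq_len, bos_id=0):
    """Sample sequences by rolling the policy forward, then apply a
    REINFORCE-style update. `rewards_fn(tokens)` is assumed to return a
    (batch, seq_len) tensor of rewards, e.g. Monte Carlo-averaged
    discriminator scores for each prefix."""
    token = torch.full((batch_size,), bos_id, dtype=torch.long)
    state, log_probs, tokens = None, [], []
    for _ in range(seq_len):
        log_p, state = policy.step(token, state)
        token = torch.multinomial(log_p.exp(), num_samples=1).squeeze(1)
        log_probs.append(log_p.gather(1, token.unsqueeze(1)).squeeze(1))
        tokens.append(token)
    tokens = torch.stack(tokens, dim=1)             # (batch, seq_len)
    log_probs = torch.stack(log_probs, dim=1)       # (batch, seq_len)

    rewards = rewards_fn(tokens).detach()           # no gradient through rewards
    loss = -(log_probs * rewards).sum(dim=1).mean() # raise expected reward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return tokens, loss.item()
```

Under these assumptions, usage would be: build an `LSTMPolicy`, create an optimizer over its parameters, and pass a `rewards_fn` that performs the rollout-and-score procedure from the bullet points above.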