This paper considers the problem of structured output prediction, in the specific case where the output is a sequence and we represent the sequence as a (conditional) directed graphical model that generates from the first token to the last. The paper starts from the observation that training such models by maximum likelihood (ML) does not reflect well how the model is actually used at test time. Indeed, ML training implies that the model is effectively trained to predict each token conditioned on the previous tokens *from the ground truth* sequence (this is known as "teacher forcing"). Yet, when making a prediction for a new input, the model will actually generate a sequence by producing tokens one after another and conditioning on *its own predicted tokens* instead.

So the authors propose a different training procedure, where at training time each *conditioning* ground truth token is sometimes replaced by the model's previous prediction. The choice of replacing the ground truth by the model's prediction is made by "flipping a coin" with some probability, independently for each token. Importantly, the authors propose to start with a high probability of using the ground truth (i.e. start close to ML) and anneal that probability closer to 0, according to some schedule (thus the name Scheduled Sampling).

Experiments on 3 tasks (image caption generation, constituency parsing and speech recognition), based on neural networks with LSTM units, demonstrate that this approach indeed improves over ML training in terms of the various performance metrics appropriate for each problem, and yields better sequence prediction models.

#### My two cents

Big fan of this paper. It both identifies an important flaw in how sequential prediction models are currently trained and, most importantly, suggests a solution that is simple yet effective. I also believe that this approach played a non-negligible role in Google's winning system for image caption generation in the Microsoft COCO competition.

My alternative interpretation of why Scheduled Sampling helps is that ML training does not inform the model about the relative quality of the errors it can make. In terms of ML, it is just as bad to put high probability on an output sequence that has only one wrong token as it is to put the same amount of probability on a sequence in which every token is wrong. Yet, say for image caption generation, outputting a sentence that is one word away from the ground truth is clearly preferable to making a mistake on every word (something that is also reflected in the performance metrics, such as BLEU). By training the model to be robust to its own mistakes, Scheduled Sampling ensures that errors won't accumulate and makes predictions that are entirely off much less likely.

An alternative to Scheduled Sampling is DAgger (Dataset Aggregation: \cite{journals/jmlr/RossGB11}), which, briefly put, alternates between training the model and adding to the training set examples that mix model predictions and the ground truth. However, Scheduled Sampling has the advantage that there is no need to explicitly create and store that increasingly large dataset of sampled examples, something that isn't appealing for online learning or learning on large datasets.

I'm also very curious about and interested in one of the directions for future work mentioned in the conclusion: figuring out a way to backprop through the stochastic predictions made by the model.
Indeed, as the authors point out, the current algorithm ignores the fact that, because the model sometimes takes its own previous prediction as input, there is an additional relationship between the model's parameters and its ultimate prediction, a relationship that isn't taken into account during training. To take it into account, you'd need to somehow backpropagate through the stochastic process that generated the previous token prediction. While the work on variational autoencoders has shown that we can backprop through Gaussian samples, backpropagating through the sampling of a discrete multinomial distribution is essentially an open problem. I do believe that there is work that has tried to tackle backpropagating through stochastic binary units, however, so perhaps that's a start. Anyway, if the authors could make progress on that specific issue, it could be quite useful not just in the context of Scheduled Sampling, but possibly in the context of training networks with discrete stochastic units in general!
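Going back to the training procedure itself, here is a minimal sketch of a scheduled-sampling training step for a token-level decoder. This is my own illustration rather than the authors' code: the single coin flip per time step (rather than per token per example), the model sizes, and the use of argmax to pick the fed-back token are simplifying assumptions, and the inverse-sigmoid decay is one schedule in the spirit of the paper.

```python
import math
import random

import torch
import torch.nn as nn


class ScheduledSamplingDecoder(nn.Module):
    """Toy token-level decoder trained with scheduled sampling (sketch only)."""

    def __init__(self, vocab_size, emb_dim=64, hidden_dim=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.cell = nn.LSTMCell(emb_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, targets, epsilon):
        """targets: (batch, seq_len) ground-truth token ids, starting with <bos>.
        epsilon: probability of feeding the ground-truth token at each step."""
        batch, seq_len = targets.shape
        h = torch.zeros(batch, self.cell.hidden_size, device=targets.device)
        c = torch.zeros_like(h)
        prev = targets[:, 0]
        loss = 0.0
        for t in range(1, seq_len):
            h, c = self.cell(self.embed(prev), (h, c))
            logits = self.out(h)
            loss = loss + nn.functional.cross_entropy(logits, targets[:, t])
            # Coin flip at each step (simplification: one flip per step for the
            # whole batch): ground truth vs. the model's own previous output.
            if random.random() < epsilon:
                prev = targets[:, t]          # teacher forcing
            else:
                prev = logits.argmax(dim=-1)  # feed back the model's prediction
                                              # (sampling the softmax is another option)
        return loss / (seq_len - 1)


def inverse_sigmoid_decay(step, k=1000.0):
    """Anneal epsilon from ~1 towards 0 as training progresses."""
    return k / (k + math.exp(step / k))
```

In a full training run one would decrease `epsilon` according to such a schedule between updates, which is exactly the "start close to ML, anneal towards free-running generation" idea described above.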
---
Google's team developed scheduled sampling as an alternative training procedure to fit RNNs, and they used it in their competition-winning method for image captioning. While I can't argue with the empirical results (so I won't), I was a bit skeptical about the technique at a fundamental level, so I decided to do a bit of math that resulted in this blog post. Overall, I have a suspicion that scheduled sampling is a flawed objective function for unsupervised/generative modelling, and I want to use this post to explain why I think so. I hope the comments section will work this time so people can comment and argue otherwise. Please also shoot me an email if you have more to say.

#### Summary of this note

- I take a critical look at scheduled sampling as an objective function for training RNNs.
- I show it can lead to pathologies where the RNN learns marginal instead of conditional distributions.
- I explain why I think adversarial training/generative moment matching offers a better alternative.
- Lastly, I include a paragraph in which I apologise for being a di*k again.

#### Strictly proper scoring rules

I've mentioned scoring rules on this blog many times, and my PhD thesis was about them, so saying I'm obsessed with this topic would be a valid observation. But this is important stuff, particularly for unsupervised learning, and particularly as a framework to think about hard concepts like overfitting in generative models.

Scoring rules are essentially loss functions for probabilistic models/forecasts. A scoring rule $S(x,Q)$ simply measures how bad a probabilistic forecast $Q$ for a variable is in the light of an actual observation $x$. In this notation, lower is better. A scoring rule is called strictly proper if, for any $P$, the following holds:

$$\underset{Q}{\operatorname{argmin}}\ \mathbb{E}_{x\sim P}S(x,Q) = P$$

In other words, if you repeatedly sample observations from some true underlying distribution $P$, then the model $Q$ which minimises the expected score is $P$ itself. This means that the scoring rule cannot be fooled and that minimising the expected score yields a consistent estimator for $P$. Because I mention consistency, people may dismiss this as a learning theory argument, but it is not. If you are a Bayesian or a deep learning person with no interest in consistency, a scoring rule being strictly proper simply means that it is safe to use it as a loss function. Anything that's not strictly proper is weird and wrong: it will lead to learning the wrong thing. This concept is central in unsupervised learning and generative modelling. Unsupervised learning is all about modelling the probability distribution of data, so it's essential that we have loss functions that can measure the discrepancy between our model $Q$ and the true data distribution $P$ in a consistent way.

#### log-likelihood

One of the most frequently used strictly proper scoring rules is the logarithmic score:

$$S(x,Q) = - \log Q(x)$$

This quantity is also known as the negative log-likelihood. Minimising the expected score in an i.i.d. scenario yields maximum likelihood estimation, which is known to be a consistent estimator and has nice properties. Often, the likelihood is impossible to evaluate. Luckily, it is not the only strictly proper scoring rule. In the context of generative models, people have used the pseudolikelihood, score matching and moment matching, all of which are examples of strictly proper scoring rules.
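As a toy numerical illustration of what "strictly proper" buys you (my own example, not from the post): for a Bernoulli variable with true parameter $p$, the forecast $q$ that minimises the expected log score is exactly $q = p$.

```python
import numpy as np

p = 0.3                               # true Bernoulli parameter P(x=1)
q = np.linspace(0.001, 0.999, 999)    # candidate forecasts Q(x=1)

# Expected log score under the true distribution:
# E_{x~P}[-log Q(x)] = -p*log(q) - (1-p)*log(1-q)
expected_score = -p * np.log(q) - (1 - p) * np.log(1 - q)

print(q[np.argmin(expected_score)])   # ~0.3: the minimiser is the true parameter
```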
To recap, any learning method that corresponds to minimising a strictly proper scoring rule is fine; everything else can go horribly wrong: even if we feed it infinite data, it might just learn the wrong thing.

#### Scheduled Sampling

After successfully establishing myself as a proper-scoring-rule-nazi, let's talk about scheduled sampling (SS). I don't have a lot of space to explain SS in great detail here, only the basic idea. I encourage everyone to read the paper and Hugo's summary above.

SS is a new method to train recurrent neural networks (RNNs) to model sequences. I will use character-by-character models of text as an example. Typically, when you train an RNN, you aim to maximise the log predictive likelihood of the next character in each training sentence, given the prefix string of previous characters. This can be thought of as a special case of maximum likelihood learning, and is all fine; you can actually do this properly without approximations. After training, you use the RNN to generate sample sentences in a recursive fashion: assuming you've already generated $n$ characters, you feed that prefix into the RNN and ask it to predict the $(n+1)$st character. The $(n+1)$st character is then added to the prefix to predict the $(n+2)$nd character, and so on.

The authors say there is a disconnect between how the model is trained (it's always fed real data) and how it's used (it's always fed synthetic data generated by itself). This, they argue, leads to the RNN being unable to recover from its own mistakes early on in the sentence. To address this, the authors propose an alternative training strategy, where every once in a while the network is given its own synthetic data instead of real data at training time. More specifically, for each character in the training sentences, we flip a coin to decide whether we feed the character from the real training sentence, or whether to feed the model's own prediction as to what that character would have been. The authors claim this makes the model more robust to recovering from mistakes, which is probably true. As far as I'm concerned, I'm happy as long as the new training procedure corresponds to a strictly proper scoring rule. But in this case, I have a strong feeling that it does not.

#### case study: sequence of two variables

For the sake of simplicity, let's consider using scheduled sampling to learn the joint distribution of a sequence of just two random variables. This is probably the simplest (shortest) time series I can think of. So SS in this case works as follows: for each datapoint, train the network to predict the real $x_1$. Then we flip a coin to decide whether to keep $x_1$ from the datapoint, or to replace it with a sample from the model $Q_{x_1}$. Then we train $Q_{x_2\vert x_1}$ on the $(x_1,x_2)$ pair obtained this way. The scoring rule for scheduled sampling looks something like this:

$$ S(Q_{x_1,x_2},(x_1,x_2)) = - (1 - \epsilon) \left[ \mathbb{E}_{z \sim Q_{x_1}} \log Q_{x_2 \vert x_1}(x_2 \vert z) + \log Q_{x_1}(x_1)\right] - \epsilon \log Q_{x_1 , x_2}(x_1,x_2),$$

where $\epsilon$ is the probability with which the true $x_1$ is used. The authors suggest starting training with $\epsilon=1$ and annealing it so that by the end of training $\epsilon=0$. So as far as the eventual optimum of SS is concerned, we only have to focus on what the first term of the scoring rule does. The second term is the good old log-likelihood, so we know that part works.
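As a quick numerical sanity check on what that first term alone asks of the conditional (my own sketch, with $Q_{x_1}$ held fixed at $P_{x_1}$ so that only the conditional model varies): at $\epsilon = 0$, a conditional that simply copies the marginal $P_{x_2}$ scores better than the true conditional $P_{x_2\vert x_1}$.

```python
import numpy as np

# True joint P over two binary variables: x2 copies x1 with probability 0.9.
p_x1 = np.array([0.5, 0.5])                      # P(x1)
p_x2_given_x1 = np.array([[0.9, 0.1],            # P(x2 | x1=0)
                          [0.1, 0.9]])           # P(x2 | x1=1)
p_joint = p_x1[:, None] * p_x2_given_x1          # P(x1, x2)
p_x2 = p_joint.sum(axis=0)                       # marginal P(x2) = [0.5, 0.5]

def expected_ss_score(q_x2_given_x1, eps):
    """Expected value of the conditional part of the SS score, taking
    Q_{x1} = P_{x1} so that only the conditional model varies."""
    score = 0.0
    for x1 in (0, 1):
        for x2 in (0, 1):
            for z in (0, 1):
                # Probability that the symbol fed into the conditional is z:
                # the true x1 with prob eps, a model sample with prob 1 - eps.
                feed = eps * (z == x1) + (1 - eps) * p_x1[z]
                score -= p_joint[x1, x2] * feed * np.log(q_x2_given_x1[z, x2])
    return score

q_true_cond = p_x2_given_x1            # candidate A: the true conditional
q_marginal = np.tile(p_x2, (2, 1))     # candidate B: ignores x1 entirely

for eps in (1.0, 0.0):
    print(eps, expected_ss_score(q_true_cond, eps), expected_ss_score(q_marginal, eps))
# eps=1: the true conditional scores better (ordinary maximum likelihood).
# eps=0: the marginal scores better, i.e. the optimum ignores x1.
```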
After some math, one can show that scheduled sampling with a fixed $\epsilon$ minimises the following divergence between the true $P$ and the model $Q$:

$$D_{SS}[P\|Q] = KL[P_{x_1}\|Q_{x_1}] + (1-\epsilon)\, \mathbb{E}_{z\sim Q_{x_1}} KL[P_{x_2}\|Q_{x_2\vert x_1=z}] + \epsilon\, KL[P_{x_2\vert x_1}\|Q_{x_2\vert x_1}]$$

Now, if $\epsilon=1$, we recover the Kullback-Leibler divergence between the joints $P_{x_1,x_2}$ and $Q_{x_1,x_2}$, which is what we expect, as it corresponds to maximum likelihood estimation. However, as $\epsilon$ is annealed to $0$, the objective function becomes somewhat strange: the conditional distribution $Q_{x_2\vert x_1}$ is pushed to model the marginal distribution $P_{x_2}$, instead of $P_{x_2\vert x_1}$ as one would expect. One can therefore see that the factorised $Q^{*} = P_{x_1}P_{x_2}$ minimises this objective function.

#### what this means for text modeling

Extrapolating from the two-variable case to longer sequences, one can see that the scheduled sampling objective would fail if minimised properly until convergence. Consider the case when the $\epsilon\approx 0$ stage is reached in the annealing schedule. Now consider what the RNN has to do to predict the $n$th character in a string during training. It is fed a random prefix string that was generated by the model itself, without ever seeing any real data. Then the RNN has to give a probabilistic forecast of what the $n$th character in the training sentence is, having seen none of the previous characters in that sentence. The optimal model that minimises this objective would completely ignore all the characters in the sentence so far, but keep a simple linear counter that indexes where it is within the sentence. It would then emit a character from an index-specific marginal distribution of characters. This is the equivalent of the factorised trivial solution above. Yes, such a model would be better at "recovering from its own mistakes", because at every character it would start independently from what it has generated so far. But this comes at the cost of paying no attention whatsoever to what the prefix of the sentence was. I believe the reason why this trivial behaviour was not observed in the paper is that the authors did not run the optimisation until convergence, and did not implement the full gradient of the objective function, as they discuss in the paper.

#### Constructive part of criticism: what to do instead of SS?

So the observed problem was that RNNs trained via maximum likelihood are unable to recover from their own mistakes early on in a sentence when they are used to generate. **The main reason for the observed problem is that the log-likelihood is a local scoring rule.** The locality property of scoring rules means that at training time we only ever evaluate the model $Q$ on actually observed datapoints. So if the RNN is faced with a prefix subsequence that was not in the dataset, God knows what it's going to complete that sentence with. The proper (shall I say strictly proper) way to fix this issue is to use a scoring rule that is not local, such as the ones underlying adversarial training or generative moment matching.
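To make the non-local alternative a bit more concrete, here is a minimal sketch (my own, not code from the post) of the squared maximum mean discrepancy with a Gaussian kernel, the quantity minimised in generative moment matching. It compares samples the model actually generates against data samples, so the model is also scored on sequences that never appear in the training set.

```python
import torch

def gaussian_kernel(a, b, bandwidth=1.0):
    """k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 * bandwidth^2))."""
    sq_dists = torch.cdist(a, b) ** 2
    return torch.exp(-sq_dists / (2.0 * bandwidth ** 2))

def mmd2(model_samples, data_samples, bandwidth=1.0):
    """Simple (biased) estimate of squared MMD between two sample sets,
    each an (n, d) tensor, e.g. embeddings of generated vs. real sequences."""
    k_xx = gaussian_kernel(model_samples, model_samples, bandwidth).mean()
    k_yy = gaussian_kernel(data_samples, data_samples, bandwidth).mean()
    k_xy = gaussian_kernel(model_samples, data_samples, bandwidth).mean()
    return k_xx + k_yy - 2.0 * k_xy
```

Of course, backpropagating such a loss through an RNN that emits discrete characters runs straight into the backprop-through-sampling problem discussed in the first summary above, which is part of why this remains an open direction.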