As per the “holistic” in the paper title, the goal of this work is to take a suite of existing ideas from semi-supervised learning and combine them into one training pipeline that, with really impressive empirical success, leverages their complementary advantages.
The core premise of semi-supervised learning is that, given true-label training signal for only a small number of examples, you can leverage large amounts of unlabeled data to improve your model. A central intuition behind many of these methods is that, even if you don’t know the class of a given sample, you know it *has* a class. Consistency losses build on this: you can push your model to predict the same class for an example and for a modified or perturbed version of that example, since, if you have a prior belief that the modification shouldn’t change the true class label, your unlabeled data point should get the same class prediction both times. Entropy minimization is built off a similar notion: although we don’t know a point’s class, we know it must have one, so we’d like our model to make a prediction that puts most of its weight on a single class rather than being spread out, since we know the “correct model” would make a very confident prediction of one class, even though we don’t know which one. These methods give context and a frame of mind for understanding the techniques merged together into the MixMatch approach.
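To make those two intuitions concrete, here is a minimal numpy sketch of the corresponding penalty terms. The distributions are made-up model outputs purely for illustration; a real implementation would compute these per batch inside a deep learning framework.

```python
import numpy as np

def consistency_loss(p_orig, p_aug):
    # Squared-error penalty on disagreement between the model's predicted
    # class distributions for an example and a perturbed version of it.
    return np.sum((p_orig - p_aug) ** 2)

def entropy(p):
    # Entropy of a predicted distribution; minimizing it pushes the model
    # toward confident, peaked predictions without needing a true label.
    return -np.sum(p * np.log(p + 1e-8))

# Made-up predictions for an unlabeled image and an augmented copy of it:
p_orig = np.array([0.6, 0.3, 0.1])
p_aug = np.array([0.5, 0.4, 0.1])
print(consistency_loss(p_orig, p_aug), entropy(p_orig))
```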
At its very highest level, MixMatch’s goal is to take in a dataset of both labeled and unlabeled data and produce a training set of inputs and targets (labels that are sometimes true and sometimes constructed or modified) from which to calculate a model-update loss.
https://i.imgur.com/6lHQqMD.png
- First, for each unlabeled example in the dataset, we produce K different augmented versions of that image (by cropping it, rotating it, flipping it, etc.). This is in the spirit of the consistency-loss literature, where you want your model to make the same prediction across augmentations.
- Do the same augmentation for each labeled example, but only once per input, rather than K times.
- Run all of the augmented versions of each unlabeled example through your model and take the average of their predictions. This is based on the idea that the average of the predictions will be a lower-variance, more stable pseudo-target to pull each of the individual predictions toward. Also, in the spirit of making something shaped more like a real label, they undertake a sharpening step, turning down the temperature of the averaged distribution, which has the effect of more confidently pulling the original predictions toward a single “best guess” label (see the first sketch after this list).
- At this point, we have a set of augmented labeled data with true labels, and a set of augmented unlabeled data with labels based on an averaged and sharpened best guess from the model over the different modifications. The pipeline then uses something called “MixUp” (on which there is a previous paper, so I won’t dive into it too much here), which takes pairs of data points, calculates a convex combination of the two inputs, runs that through the model, and uses a matching convex combination of the two labels as the loss-function target. So, in the simple binary case, if you have a positively and a negatively labeled image and sample a combination parameter of 0.75, you get an input that is 0.75 of the positive image and 0.25 of the negative one, and the new label you calculate cross-entropy loss against is 0.75 (see the MixUp sketch after this list).
- MixMatch generates pairs for its MixUp calculation by mixing (heh) the labeled and unlabeled data together, and pairing each labeled and unlabeled example with one observation drawn from the shuffled merged set.
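Here is a minimal numpy sketch of the label-guessing step (average over the K augmented predictions, then sharpen). T=0.5 follows the temperature the paper reports as its default, but the prediction values themselves are made up:

```python
import numpy as np

def sharpen(p, T=0.5):
    # Lower the "temperature" of a distribution: raise each probability
    # to 1/T and renormalize, so mass concentrates on the top class.
    p = p ** (1.0 / T)
    return p / p.sum()

# Made-up model outputs for K=2 augmentations of one unlabeled image:
preds = np.array([[0.55, 0.30, 0.15],
                  [0.45, 0.40, 0.15]])
q_bar = preds.mean(axis=0)  # average over the K augmented predictions
q = sharpen(q_bar)          # sharpened pseudo-label used as the target
print(q_bar, q)  # q puts noticeably more weight on class 0 than q_bar
```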
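And a sketch of the MixUp pairing step. The 4-pixel “images” and 3-class labels are hypothetical stand-ins; the `max(lam, 1 - lam)` adjustment is the tweak MixMatch applies so each mixed example stays dominated by its own (true or guessed) label:

```python
import numpy as np

rng = np.random.default_rng(0)

def mixup(x1, p1, x2, p2, alpha=0.75):
    # Sample a mixing coefficient from Beta(alpha, alpha), biased toward
    # the first argument so the mixed example stays closer to its own label.
    lam = rng.beta(alpha, alpha)
    lam = max(lam, 1.0 - lam)
    return lam * x1 + (1.0 - lam) * x2, lam * p1 + (1.0 - lam) * p2

# Hypothetical flattened images (4 pixels) and label distributions (3 classes):
labeled_x, labeled_p = rng.random((2, 4)), np.eye(3)[[0, 2]]
unlabeled_x, unlabeled_p = rng.random((2, 4)), rng.dirichlet(np.ones(3), 2)

# Pairing: concatenate the labeled and unlabeled batches, shuffle, and mix
# each original example with one example from the shuffled merged set.
merged_x = np.concatenate([labeled_x, unlabeled_x])
merged_p = np.concatenate([labeled_p, unlabeled_p])
perm = rng.permutation(len(merged_x))
mixed = [mixup(merged_x[i], merged_p[i], merged_x[j], merged_p[j])
         for i, j in enumerate(perm)]
```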
At this point, we have mixed inputs and mixed labels, and we can calculate a loss between the model’s predictions on those inputs and those labels.
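The final loss treats the two halves differently: cross-entropy for the labeled portion and a squared-error term for the unlabeled portion, following the split the paper describes. A sketch, where the weight `lambda_u` is just an illustrative placeholder for the paper’s unlabeled-loss coefficient:

```python
import numpy as np

def mixmatch_loss(pred_x, target_x, pred_u, target_u, lambda_u=100.0):
    # Cross-entropy against the (mixed) true labels for the labeled half...
    l_x = -np.mean(np.sum(target_x * np.log(pred_x + 1e-8), axis=1))
    # ...and a weighted squared-error term against the (mixed) guessed
    # labels for the unlabeled half, which is less punishing when a
    # confident pseudo-label turns out to be wrong.
    l_u = np.mean(np.sum((pred_u - target_u) ** 2, axis=1))
    return l_x + lambda_u * l_u

# Toy call with one example in each half and 3 classes:
print(mixmatch_loss(np.array([[0.7, 0.2, 0.1]]), np.array([[1.0, 0.0, 0.0]]),
                    np.array([[0.5, 0.3, 0.2]]), np.array([[0.6, 0.3, 0.1]])))
```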
With all of these methods combined, MixMatch takes the previous benchmark of 38% error on CIFAR-10 with only 250 labels and drops it to 11%, which is a pretty astonishing improvement in error rate. In an ablation study, they find that MixUp itself, temperature sharpening, and calculating K>1 augmentations of the unlabeled data rather than K=1 are the strongest value-adds; mixing labeled and unlabeled examples together for the MixUp pairs doesn’t appear to make much difference.