This paper describes an architecture designed for generating class predictions from a set of features in situations where you may have only a few examples per class, or even where you see entirely new classes at test time. Some prior work has approached this problem in ridiculously complex fashion, up to and including training a network to predict the gradient outputs of a meta-network that it thinks would best optimize loss, given a new class. Prototypical Networks pride themselves on being much simpler and more intuitive, so I hope I’ll be able to convey that in this explanation.

https://i.imgur.com/Q45w0QT.png

In order to think about this problem properly, it makes sense to take a few steps back and think about some fundamental assumptions that underlie machine learning. One very basic one is that you need some notion of similarity between observations in your training set and potential new observations in your test set in order to properly generalize. To put it very simplistically, if a test example is very similar to examples of class A that we saw in training, we might predict it to be of class A at test time.

But what does it *mean* for two observations to be similar to one another? If you’re using a method like K Nearest Neighbors, you calculate a point’s class identity based on the training-set observations closest to it in Euclidean space, and you assume that nearness in that space corresponds to the likelihood of two data points having come from the same class. This is useful when new classes show up after training, since, well, there isn’t really a training period: the strategy for KNN is just carrying your whole training set around and, whenever a new test point comes along, calculating its closest neighbors among those training-set points.
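To make the KNN strategy concrete, here is a minimal sketch in NumPy. The function name `knn_predict`, the toy 2-D points, and the choice of k=3 are all illustrative assumptions, not anything taken from the paper.

```python
import numpy as np
from collections import Counter

def knn_predict(train_x, train_y, query, k=3):
    """Return the majority label among the k stored points closest to `query`."""
    # Euclidean distance from the query to every stored training point.
    dists = np.linalg.norm(train_x - query, axis=1)
    nearest = np.argsort(dists)[:k]
    votes = Counter(train_y[i] for i in nearest)
    return votes.most_common(1)[0][0]

# Two known classes, stored as-is: there is no training step.
train_x = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
                    [5.0, 5.0], [5.1, 4.9], [4.9, 5.2]])
train_y = ["A", "A", "A", "B", "B", "B"]

# A new class "C" shows up later: just append its few examples.
train_x = np.vstack([train_x, [[10.0, 0.0], [10.2, 0.1], [9.9, -0.1]]])
train_y = train_y + ["C", "C", "C"]

pred_a = knn_predict(train_x, train_y, np.array([0.1, 0.1]))
pred_c = knn_predict(train_x, train_y, np.array([10.1, 0.0]))
print(pred_a, pred_c)
```

Note that every stored point takes part in every prediction, which is the cost the prototype idea later avoids.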
If you see a new class in the wild, all you need to do is add the examples of that class to your group of training-set points, and then after a few examples, if your assumptions hold, you’ll be able to predict that class by (hopefully) finding those two or three points as neighbors.

But what if some dimensions of your feature space matter much more than others for differentiating between classes? In a simplistic example, you could have twenty features, but, unbeknownst to you, only one is actually useful for separating out your classes, and the other 19 are random. If you use the naive KNN assumption, you wouldn’t expect to perform well here, because the distances in those 19 meaningless directions will spread out your points, due to randomness, more than the meaningful dimension spreads them out due to belonging to different classes. And what if you want to be able to learn non-linear relationships between your features, which the composability of multi-layer neural networks lends itself well to? In cases like those, the features you were handed may be a woefully suboptimal metric space in which to calculate a kind of similarity that corresponds to differences in class identity, so you’ll just have to strike out for the territories and create a metric space for yourself.

That is, at a very high level, what this paper seeks to do: learn a transformation between input features and some vector space, such that distances in that vector space correspond as well as possible to probabilities of belonging to a given output class. You may notice me using “vector space” and “embedding” interchangeably; they are the same idea: the result of that learned transformation, which represents your input observations as dense vectors in some p-dimensional space, where p is a chosen hyperparameter.

What are the concrete learning steps this architecture goes through?

1. During each training episode, sample a subset of classes, and then divide the examples of each sampled class into training examples (the paper calls these the support set) and query examples.
2. Using a set of weights that are being learned by the network, map the input features of each training example into a vector space.
3. Once all training examples are mapped into the space, calculate a “mean vector” for class A by averaging all of the embeddings of training examples that belong to class A. This is the “prototype” for class A, and once we have it, we can forget the values of the embedded examples that were averaged to create it. This is a nice update on the KNN approach, since the number of values we need to carry around at evaluation time is only (num-dimensions) * (num-classes), rather than (num-dimensions) * (num-training-examples).
4. Then, for each query example, map it into the embedding space, and take a softmax over the negative distances from it to each class prototype to get a distribution over possible classes. (You can just think of a softmax as a network’s predicted probabilities: a set of floats that add up to 1.)
5. Then, you can calculate the (cross-entropy) error between the true output and that softmax prediction vector in the same way as you would for any classification network.
6. Add up the prediction loss for all the query examples, and then backpropagate through the network to update your weights.

The overall effect of this process is to incentivize your network to learn, not necessarily a good prediction function, but a good metric space. The idea is that, if the metric space is good enough, and the classes are conceptually similar to each other (e.g. car vs chair, as opposed to car vs the-meaning-of-life), a space that does well at causing observations from the same seen class to be close to one another will do the same for classes not seen during training.
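The episode steps described above can be sketched in a few lines of NumPy. This is only an illustration under loud assumptions: `embed` is a hypothetical stand-in for the learned network (a single linear layer plus tanh), the data is synthetic Gaussian noise around made-up class centers, and all the sizes are arbitrary. It computes the loss for one episode but does not do the gradient update of step 6, which in practice you would hand to an autodiff framework.

```python
import numpy as np

def embed(x, weights):
    """Hypothetical stand-in for the learned network: linear layer + tanh."""
    return np.tanh(x @ weights)

def episode_loss(support_x, support_y, query_x, query_y, weights, n_classes):
    # Step 2: map the support (training) examples into the embedding space.
    support_z = embed(support_x, weights)
    # Step 3: one prototype per class = mean of that class's embeddings.
    prototypes = np.stack([support_z[support_y == c].mean(axis=0)
                           for c in range(n_classes)])
    # Step 4: squared Euclidean distance from each query to each prototype,
    # followed by a softmax over the negative distances.
    query_z = embed(query_x, weights)
    sq_dists = ((query_z[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=2)
    logits = -sq_dists
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Steps 5 and 6 (loss only): cross-entropy summed over all query examples.
    loss = -log_probs[np.arange(len(query_y)), query_y].sum()
    return loss, np.exp(log_probs)

# Step 1, faked with synthetic data: 3 classes, 5 support and 4 query each.
rng = np.random.default_rng(0)
n_classes, n_support, n_query, n_features, p = 3, 5, 4, 8, 2
centers = rng.normal(scale=3.0, size=(n_classes, n_features))
support_y = np.repeat(np.arange(n_classes), n_support)
query_y = np.repeat(np.arange(n_classes), n_query)
support_x = centers[support_y] + rng.normal(size=(len(support_y), n_features))
query_x = centers[query_y] + rng.normal(size=(len(query_y), n_features))
weights = rng.normal(scale=0.1, size=(n_features, p))  # untrained weights

loss, probs = episode_loss(support_x, support_y, query_x, query_y,
                           weights, n_classes)
print(loss, probs.shape)
```

Once the prototypes are computed, `support_z` is never used again, which is the storage saving over KNN mentioned in step 3.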
I admit to not being sufficiently familiar with the datasets used for testing to have a sense for how well this method compares to more fully supervised classification schemes; if anyone does, definitely let me know! But the paper claims to get state-of-the-art results compared to other approaches in this domain of few-shot learning (matching networks, and the aforementioned meta-learning).

One interesting note is that the authors found that squared Euclidean distance, when applied within the embedded space, worked meaningfully better than cosine distance (which is a more standard way of measuring distances between vectors, since it measures only angle, rather than magnitude). They suspect that this is because squared Euclidean distance, but not cosine distance, belongs to a category of divergences (called Bregman divergences) with the special property that the point closest in aggregate to all points in a cluster is the average of those points. If you want to dive way deep into the minutiae on this point, I found this blog post quite good: http://mark.reid.name/blog/meet-the-bregman-divergences.html
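The "average is the closest point in aggregate" property is easy to check numerically for squared Euclidean distance. The sketch below uses made-up random points and a hypothetical helper `total_sq_dist`. It only illustrates the squared Euclidean case and says nothing about cosine distance or other Bregman divergences.

```python
import numpy as np

rng = np.random.default_rng(0)
points = rng.normal(size=(50, 3))  # a made-up cluster of 50 points in 3-D
mean = points.mean(axis=0)

def total_sq_dist(center):
    """Sum of squared Euclidean distances from every point to `center`."""
    return ((points - center) ** 2).sum()

at_mean = total_sq_dist(mean)

# Try many nearby candidate centers; none should beat the mean.
candidates = mean + rng.normal(scale=0.5, size=(1000, 3))
best_other = min(total_sq_dist(c) for c in candidates)

# The gap for any candidate c is exactly n * ||c - mean||^2.
c = candidates[0]
gap = total_sq_dist(c) - at_mean
expected_gap = len(points) * ((c - mean) ** 2).sum()
print(at_mean <= best_other, np.isclose(gap, expected_gap))
```

The gap identity in the last few lines shows why the mean wins: moving the center away from the mean in any direction adds a penalty that grows with the squared size of the move.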