Summary by CodyWild 5 years ago
If your goal is to interpret the predictions of neural networks on images, there are a few different ways you can focus your attention. One approach is to try to understand and attach conceptual tags to learnt features, to form a vocabulary with which models can be understood. However, techniques in this family have to contend with a number of challenges, from the difficulty of attaching clear concepts to features to the sheer number of neurons to interpret. An alternate approach, and the one pursued by this paper, is to frame interpretability as a matter of introspecting on *where in an image* the model is pulling information from to make its decision.
This is the question for which hard attention provides an answer: identify where in an image a model is making a decision by learning a meta-model that selects small patches of an image, and then make a classification decision by applying a network to only the selected patches. By definition, if only a discrete set of patches was used for prediction, those are the only ones that could be driving the model's decision. This central fact of the model choosing only a discrete set of patches is also a key complexity, since the choice to use a patch or not is a binary, discontinuous action, and not something through which one can backpropagate gradients. Saccader, the approach put forward by this paper, proposes an architecture which extracts features from locations within an image, and uses those spatially located features to inform a stochastic policy that selects each patch with some probability. Because reinforcement learning is by construction structured to allow discrete actions, the system as a whole can be trained via policy gradient methods.
https://i.imgur.com/SPK0SLI.png
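To make the policy gradient idea concrete, here is a minimal sketch of a REINFORCE-style update for a stochastic patch-selection policy. It is a toy under stated assumptions, not the paper's training procedure: the policy is just one learnable logit per patch, the reward is 1 only when a single hypothetical "informative" patch is picked, and the names `train_patch_policy`, `informative_patch`, and `lr` are made up for illustration.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def train_patch_policy(num_patches, informative_patch, steps, lr, seed=0):
    """Toy REINFORCE loop: learn which patch to look at.

    Assumption (not from the paper): reward is 1 if the sampled patch is
    the single informative one, else 0, standing in for "the classifier
    got the right answer from this patch".
    """
    rng = random.Random(seed)
    logits = [0.0] * num_patches  # start from a uniform policy
    for _ in range(steps):
        probs = softmax(logits)
        # Sampling a patch is the discrete, non-differentiable action.
        action = rng.choices(range(num_patches), weights=probs)[0]
        reward = 1.0 if action == informative_patch else 0.0
        # Gradient of log pi(action) w.r.t. the logits is onehot(action) - probs.
        for i in range(num_patches):
            grad_log_prob = (1.0 if i == action else 0.0) - probs[i]
            logits[i] += lr * reward * grad_log_prob
    return softmax(logits)


if __name__ == "__main__":
    final_probs = train_patch_policy(
        num_patches=5, informative_patch=2, steps=2000, lr=0.5
    )
    print([round(p, 3) for p in final_probs])
```

The sampling step never needs a gradient of its own; only the log-probability of the sampled action is differentiated. The sketch also hints at the sparse reward problem: with many patches and a uniform starting policy, most samples earn zero reward and produce no update at all.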
Diving into a bit more detail: while I don't have a deep familiarity with prior work in this area, my impression is that the notion of using policy gradient to learn a hard attention policy isn't a novel contribution of this work, but rather that its novelty comes from clever engineering done to make that policy easier to learn. The authors cite the problem of sparse reward in learning the policy, which I presume to mean that if you start in more traditional RL fashion by just sampling random patches, most patches will be unclear or useless in providing classification signal, so it will be hard to train well. The Saccader architecture works by extracting localized features in an architecture inspired by the 2019 BagNet paper, which essentially applies very tall and narrow convolutional stacks to spatially small areas of the image. This means that feature vectors for different overlapping patches can be computed efficiently: instead of rerunning the network for each patch, it just combines the features from the "tops" of all of the small column networks inside the patch, and uses that aggregation as a patch-level feature. These features from the "representation network" are then used in an "attention network," which uses larger receptive field convolutions to create patch-level features that integrate the context of things around them. Once these two sets of features are created, they are fed into the "Saccader cell", which uses them to calculate a distribution over patches which the policy then samples from. The Saccader cell is a simplified memory cell, which sets a value to 1 when a patch has been sampled, and applies a very strong penalty to that patch being sampled on future "glimpses" from the policy (in general, classification is performed by making a number of draws and averaging the logits produced for each patch).
https://i.imgur.com/5pSL0oc.png
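The memory mechanism described above can be sketched roughly as follows. This is my reading of the description rather than the paper's code: the per-patch location scores and per-patch class logits are taken as given inputs (in the real model they would come from the attention and representation networks), and `PENALTY`, `run_glimpses`, and the argument names are hypothetical.

```python
import math
import random

# Hypothetical constant: large enough that a visited patch gets ~zero probability.
PENALTY = 1e5


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def run_glimpses(location_scores, patch_class_logits, num_glimpses, seed=0):
    """Sample `num_glimpses` patches without revisiting, then average logits.

    location_scores: one score per patch (assumed given).
    patch_class_logits: one list of class logits per patch (assumed given).
    Requires num_glimpses <= number of patches.
    """
    rng = random.Random(seed)
    visited = [0] * len(location_scores)  # the cell's memory: 1 = already sampled
    chosen = []
    for _ in range(num_glimpses):
        # Strongly penalize patches whose memory bit is set.
        masked = [s - PENALTY * v for s, v in zip(location_scores, visited)]
        probs = softmax(masked)
        idx = rng.choices(range(len(probs)), weights=probs)[0]
        visited[idx] = 1
        chosen.append(idx)
    # Final prediction: average the class logits over all glimpsed patches.
    num_classes = len(patch_class_logits[0])
    avg_logits = [
        sum(patch_class_logits[i][c] for i in chosen) / len(chosen)
        for c in range(num_classes)
    ]
    return chosen, avg_logits


if __name__ == "__main__":
    scores = [0.1, 2.0, -1.0, 0.5, 1.5, 0.0]
    logits = [[0.2, 0.8], [1.5, -0.5], [0.0, 0.0], [0.3, 0.1], [2.0, -1.0], [0.1, 0.4]]
    print(run_glimpses(scores, logits, num_glimpses=4))
```

Because the penalty drives a visited patch's probability to effectively zero, each glimpse lands on a new location, so the averaged logits draw on several distinct parts of the image.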
I found this paper fairly conceptually clever - I hadn't thought much about using a reinforcement learning setup for classification before - though a bit difficult to follow in its terminology and notation. It's able to perform relatively well on ImageNet, though I'm not steeped enough in that as a benchmark to have an intuitive sense for the authors' claim that their accuracy is meaningfully in the same ballpark as full-image models. One interesting point the paper makes is that the system, while limited to small receptive fields for the patch features, can use an entirely different model for mapping patches to logits once the patches are selected, and so can benefit from more powerful generic classification models being tacked onto the end.