Attention mechanisms are a common component of language models. They were first used as part of recurrent models and, more recently, as a standalone way of aggregating information over sequences, independent of any recurrent structure. Attention takes as input a sequence of representations, most typically embeddings of the words in a sentence, and learns a distribution of weights over them. The network then uses these weights to aggregate the representations, typically by taking a weighted sum. One effect of using attention is that, for each instance being predicted, the network produces this weight distribution over its inputs. Intuitively, this looks like the network showing which input words were most important to its decision. As a result, work using attention has often included figures showing attention distributions for individual examples, implicitly treating them as a form of interpretability or model explanation.
This paper sets out to determine whether attention distributions can be seen as a valid form of feature importance, and takes the position that they shouldn’t be. At a high level, I think the paper makes some valid criticisms, but ultimately I didn’t find its evidence as strong as I would have liked.
The paper performs two primary analyses of the attention distributions produced by a trained LSTM model. (1) It calculates the correlation between the importance implied by attention weights and the importance calculated using more canonical gradient-based methods. These methods generally answer questions of the shape “which words contributed the most towards the prediction being what it was”. The correlations it finds vary across random seeds but are generally centered around 0.5.
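To make the two ingredients of the first analysis concrete, here is a minimal Python sketch, not the paper’s code. It applies softmax attention over a few toy word vectors, takes the weighted-sum aggregation, and computes a rank correlation between the attention weights and a separate importance vector. The vectors, attention scores and `grad_importance` values are all made up; `grad_importance` merely stands in for a gradient-based importance score. Using Kendall’s tau as the rank-correlation measure is my assumption about the kind of statistic involved, not a claim about the paper’s exact setup.

```python
import math

def softmax(scores):
    """Turn raw attention scores into a probability distribution."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(vectors, scores):
    """Aggregate per-token vectors as a weighted sum, weighted by softmax(scores)."""
    weights = softmax(scores)
    dim = len(vectors[0])
    pooled = [sum(w * v[d] for w, v in zip(weights, vectors)) for d in range(dim)]
    return weights, pooled

def kendall_tau(xs, ys):
    """Kendall's tau-a: (concordant - discordant) / number of pairs, no tie correction."""
    n = len(xs)
    concordant = discordant = 0
    for i in range(n):
        for j in range(i + 1, n):
            s = (xs[i] - xs[j]) * (ys[i] - ys[j])
            if s > 0:
                concordant += 1
            elif s < 0:
                discordant += 1
    return (concordant - discordant) / (n * (n - 1) / 2)

# Hypothetical toy input: 4 tokens with 2-d representations.
vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
scores = [2.0, 0.5, 1.0, 0.0]  # hypothetical attention logits
weights, pooled = attend(vectors, scores)

# Hypothetical gradient-based importance scores for the same 4 tokens.
grad_importance = [0.9, 0.1, 0.2, 0.3]
print("attention weights:", [round(w, 3) for w in weights])
print("pooled vector:", [round(x, 3) for x in pooled])
print("Kendall tau vs. gradient importance:", kendall_tau(weights, grad_importance))
```

Here both methods agree that the first token matters most but disagree on the ordering of the rest, which yields a tau of 1/3. The debate in this review is not about how such a number is computed but about what value of it should count as “high enough”.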
The paper frames this as a negative result, implying that if attention were a valid measure of importance, its correlation with existing metrics would be higher. I definitely follow the intuition that you would expect a significant, positive correlation between methods in this class. However, it’s unclear to me what a priori reasoning sets the threshold for “significant” such that 0.5 falls below it. It feels like one of those cases where someone could look at the same plots and reach a different interpretation, and I don’t see what criteria support one threshold of magnitude over another.
(2) It measures how far it can perturb the weights of an attention distribution without meaningfully changing the network’s prediction. It does this both with random permutations of the weights and by finding the maximum adversarial perturbation: the farthest-away distribution (in terms of Jensen-Shannon distance) that still produces a prediction within some epsilon of the original. I have a few concerns about this analysis. First, it assumes that attention can only be a valid form of explanation if it is a causal mechanism within the model. You could imagine attention distributions still giving you information about the internal state of the model, even if they only report that state rather than directly influencing it. Second, it seems possible to me that you could get a relatively high Jensen-Shannon distance from an initial distribution just by permuting the indexes of the low-value weights and shifting weight among them. Such a change need not fundamentally alter what the network is primarily attending to.
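To make the second concern concrete, here is a small Python sketch with made-up numbers, not taken from the paper. It computes the Jensen-Shannon divergence between a peaked attention distribution and two alternatives. One alternative only rearranges the weights other than the peak. The other moves the peak to a different token. The distributions and function names are hypothetical illustrations.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) in nats; terms where p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js_divergence(p, q):
    """Jensen-Shannon divergence in nats, bounded above by ln 2.
    Its square root is the Jensen-Shannon distance, which is a true metric."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

def top_index(p):
    """Index of the most-attended token."""
    return max(range(len(p)), key=lambda i: p[i])

# Hypothetical attention distribution: one dominant token plus smaller weights.
original = [0.60, 0.25, 0.10, 0.05]
# Variant A: swap two non-peak weights; token 0 still dominates.
tail_shuffled = [0.60, 0.05, 0.10, 0.25]
# Variant B: swap the peak with the runner-up; token 1 now dominates.
peak_moved = [0.25, 0.60, 0.10, 0.05]

for name, q in [("tail_shuffled", tail_shuffled), ("peak_moved", peak_moved)]:
    print(name, f"JSD={js_divergence(original, q):.4f}", "top token:", top_index(q))
# prints:
# tail_shuffled JSD=0.0728 top token: 0
# peak_moved JSD=0.0742 top token: 1
```

With these toy numbers, rearranging the non-peak weights moves the distribution almost as far, in JS terms, as relocating the peak, even though only the second change alters which token dominates. Reporting whether the top-attended token (or top-k set) changes alongside the JS value is one simple way to tell these cases apart.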
Even if this is not the case in this paper, I’d love to see an example or some kind of quantitative measure showing that the Jensen-Shannon distances they report require a substantive change in weight priorities, rather than a trivial one.
Another general critique is that the experiments in this paper focus only on attention within an LSTM structure. There, the embedding associated with each word isn’t strictly an embedding of that specific word: because of the nature of a BiLSTM, it also contains a lot of information about the words before and after it. So there is some specificity in the embedding corresponding to that word, but much less than in a pure attention model like some of those used in NLP today, where the attention distribution is learned over the raw, non-LSTM-ed representations. In the BiLSTM case, it makes sense that attention would be blurry and not map exactly onto our notions of which words are more important, since the word-level representations are themselves already aggregations. I think it’s totally fair to focus only on the LSTM case, but I would prefer that the paper scope its claims more closely to its empirical results.
I feel a bit bad: overall, I really approve of papers like this that put a critical empirical frame on ML’s tendency to get conceptually ahead of itself. And I do think the evidentiary standard for “prove that metric X isn’t a form of interpretability” shouldn’t be that high, because on priors I would expect most things not to be. They may well be right in their assessment; I would just like a more surefooted set of analyses and interpretation behind it.