Summary by CodyWild 3 years ago
This was an amusingly-timed paper for me to read, because just yesterday I was listening to a different paper summary where the presenter offhandedly mentioned the idea of compressing the sequence length in Transformers through subsequent layers (the way a ConvNet does pooling to a smaller spatial dimension in the course of learning), and it made me wonder why I hadn't heard much about that as an approach. And, lo, I came on this paper in my list the next day, which does exactly that.
As a refresher, Transformers work by starting out with one embedding per token in the first layer, and, on each subsequent layer, they create new representations for each token by calculating an attention mechanism over all tokens in the prior layer. This means you have one representation per token for the full sequence length, and for the full depth of the network. In addition, you typically have a CLS token that isn't connected to any particular word, but is the designated place where sequence-level representations aggregate and are used for downstream tasks.
This paper notices that many applications of trained transformers care primarily about that aggregated representation, rather than precise per-word representations. For cases where that's true, you're spending a lot of computation power on continually calculating the SeqLength^2 attention maps in later layers, when they might not be bringing you that much value in your downstream transfer tasks.
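To put a rough number on that cost, here is a back-of-the-envelope sketch that counts only the query-key score entries per layer. The block lengths (512, 256, 128) and the four layers per block are illustrative values I picked, not the paper's configuration, and the count ignores feed-forward layers and everything else in the network.

```python
def attention_score_entries(block_lengths, layers_per_block):
    """Count query-key score entries summed over all layers.

    block_lengths[i] is the (hypothetical) sequence length inside block i.
    The first layer of each block uses queries at the new length against
    keys from the previous block's length; the remaining layers are
    ordinary self-attention at the new length.
    """
    total = 0
    prev_len = block_lengths[0]
    for length in block_lengths:
        # First layer of the block: shorter queries, longer keys.
        total += length * prev_len
        # Remaining layers: standard SeqLength^2 self-attention.
        total += (layers_per_block - 1) * length * length
        prev_len = length
    return total


# Same depth (12 layers), with and without shrinking the sequence.
full = attention_score_entries([512, 512, 512], layers_per_block=4)
funnel = attention_score_entries([512, 256, 128], layers_per_block=4)
print(full, funnel, round(funnel / full, 3))
```

Under these made-up settings, the shrinking version computes a bit under half of the attention scores of the constant-length version at the same depth, which is the budget the paper can reinvest in extra layers.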
A central reason why you do generally need per-token representations in training Transformers, though, even if your downstream tasks need them less, is that the canonical Masked Language Model and newer ELECTRA loss functions require token-level predictions (for the masked tokens in MLM, and a replaced-or-original decision per token in ELECTRA). To accommodate this need, the authors of this paper structure their "Funnel" Transformer as more of an hourglass. This turns the model into basically a VAE-esque Encoder/Decoder structure, where attention downsampling layers reduce the length of the internal representation, and then a "decoder" expands it back to the full sequence size, so you have one representation per token for training purposes (more on the exact way this works in a bit). The nifty thing here is that, for downstream tasks, you can chop off the decoder, and be left with a network with comparatively less computation cost per layer of depth.
https://i.imgur.com/WC0VQXi.png
The exact mechanisms of downsampling and upsampling in this paper are quite clever.
To perform downsampling at a given attention layer, you take a sequence of representations h, and downsample it to h' of half the length by mean-pooling adjacent tokens. However, in the attention calculation, you only use h' for the queries, and use the full sequence h for the keys and values. Essentially, this means that you have an attention layer where the downsampled representations attend to and pull information from the full scope of the (non-downsampled) representations of the layer below. This gives you a much more flexible downsampling operation, since the attention mechanism can choose which information to pull into the downsampled representation, rather than having it fixed in advance by a pooling operation.
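A minimal sketch of that pooled-query attention, in plain Python so every step is visible. This is my own simplification, not the authors' code: it uses a single head, and it leaves out the usual Transformer sublayer pieces (learned projections, residual connections, normalization, positional information). The function names are made up for illustration.

```python
import math


def mean_pool_pairs(h):
    """Average adjacent pairs of token vectors: length n -> n // 2 (n assumed even)."""
    return [
        [(a + b) / 2.0 for a, b in zip(h[i], h[i + 1])]
        for i in range(0, len(h) - 1, 2)
    ]


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def attend(queries, keys, values):
    """Scaled dot-product attention; returns one output vector per query."""
    d = len(keys[0])
    out = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        out.append(
            [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
        )
    return out


def pooled_query_attention(h):
    """Queries come from the pooled sequence h'; keys and values from the full h."""
    h_pooled = mean_pool_pairs(h)
    return attend(h_pooled, h, h)


h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]  # 4 tokens, dimension 2
h_down = pooled_query_attention(h)
print(len(h), "->", len(h_down))
```

The output has half as many positions as the input, but each output position is a weighted mix over all of the original tokens, not just the two that were averaged to form its query.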
The paper inflates the bottlenecked representations back up to the full sequence length by first tiling the downsampled representation (for example, if you had downsampled from 20 to 5, you would tile the first representation 4 times, then the second representation 4 times, and so on until you hit 20). That tiled representation, in which each vector can roughly be thought of as representing a large region of the sequence, is then added, ResNet-style, to the full-length sequence of representations that came out of the first attention layer, essentially combining shallow token-level representations with deep region-level representations. This aggregated representation is then used for token-level loss prediction.
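The tiling-plus-residual step is simple enough to sketch directly. This is an illustration of the description above using the same 20-to-5 example, with made-up function names and toy one-dimensional vectors. It is not the authors' implementation.

```python
def tile_upsample(h_short, factor):
    """Repeat each region-level vector `factor` times in place: length m -> m * factor."""
    return [list(vec) for vec in h_short for _ in range(factor)]


def add_residual(h_tokens, h_regions):
    """Element-wise sum of shallow token-level and tiled region-level vectors."""
    if len(h_tokens) != len(h_regions):
        raise ValueError("sequences must have the same length")
    return [[a + b for a, b in zip(t, r)] for t, r in zip(h_tokens, h_regions)]


# Toy example: 5 deep region vectors, 20 shallow token vectors.
deep = [[float(i)] for i in range(5)]
shallow = [[10.0] for _ in range(20)]

tiled = tile_upsample(deep, factor=4)    # positions 0-3 share region 0, 4-7 region 1, ...
combined = add_residual(shallow, tiled)  # one vector per original token again
print(len(tiled), combined[0], combined[19])
```

Tokens within the same region get an identical deep component, so anything that distinguishes them has to come from the shallow full-length representations.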
The authors benchmark against common baseline models, using deeper models with fewer tokens per layer, and find that they can reach similar or higher levels of performance with fewer FLOPs on tasks that rely on an aggregated, sequence-level representation. They fall short of full-sequence models for tasks that require strong per-token representations, which fits with my expectation.