Summary by CodyWild 5 years ago
The Transformer, a somewhat confusingly-named model structure that uses attention mechanisms to aggregate information for understanding or generating data, has been having a real moment in the last year or so, with GPT-2 being only the most well-publicized tip of that iceberg. It has lots of advantages: the obvious attraction of strong performance, as well as the ability to train in parallel across parts of a sequence, which RNNs can’t do because of their need to build up and maintain state. However, a problematic fact about the Transformer approach is how it scales to long input sequences. Because attention performs pairwise queries between each point in the sequence and every other point, so that information can be aggregated from anywhere in the sequence, it scales as O(N^2): every new element needs to be queried against the N elements already there. This makes it resource-intensive to run Transformer models on long sequences.
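To make that quadratic cost concrete, here is a minimal sketch of ordinary dense scaled dot-product attention (my own illustration, not the paper's code; the function and variable names are mine): the score matrix that gets materialized is N x N, so both compute and memory grow quadratically with sequence length.

```python
import numpy as np

def dense_attention(Q, K, V):
    """Q, K, V: arrays of shape (N, d). Returns an (N, d) output."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)  # (N, N): every query attends to every key
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ V  # aggregate values; still O(N^2) work overall

N, d = 1024, 64
Q = K = V = np.random.randn(N, d)
out = dense_attention(Q, K, V)  # materializes a 1024 x 1024 score matrix
```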
The Sparse Transformer design proposed in this OpenAI paper tries to cut down on this resource cost by loosening the requirement that, in every attention operation, each element directly pulls information from every other element. In the new system, a single operation no longer gives each point information about every other point; instead, two such limited operations chained in a row provide that global visibility. This is done in one of two ways.
(1) The first, called the “strided” version, performs two operations in a row: one masked attention that only looks at the last k timesteps (for example, the last 7), and then a second masked attention that only looks at every kth timestep. So, at the end of the second operation, each point has pulled information from checkpoints 7, 14, 21, ... steps back, and each of those checkpoints has pulled information from the window between it and its preceding checkpoint, giving visibility into a full global receptive field in the course of two operations. (Both patterns are sketched in code after item (2).)
(2) The second, called the “fixed” version, uses a similar sort of logic, but instead of having the “window accumulation points” be defined relative to the point doing the querying, you have fixed accumulation points responsible for gathering information from the windows around them. So, using the example given in the paper, if you imagine a window of size 128 and 8 accumulation points per window, then the points at indices 120-128 (say) would have visibility into points 0-128. That represents the first operation; in the second, every other point in the sequence pulls in information by querying the designated accumulation points for all the windows that aren’t masked for it.
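Here is a rough numpy sketch of the two mask patterns as boolean matrices over a causal sequence (again my own illustration, not the paper's code): k is the stride / block size and c is the number of accumulation columns per block in the fixed pattern, and the exact index conventions may differ slightly from the paper's formulation.

```python
import numpy as np

def strided_masks(N, k):
    i = np.arange(N)[:, None]  # query index
    j = np.arange(N)[None, :]  # key index
    causal = j <= i
    local = causal & (i - j < k)           # first op: the last k timesteps
    strided = causal & ((i - j) % k == 0)  # second op: every kth timestep back
    return local, strided

def fixed_masks(N, k, c):
    i = np.arange(N)[:, None]
    j = np.arange(N)[None, :]
    causal = j <= i
    same_block = causal & (i // k == j // k)  # first op: within the point's own window
    summary = causal & (j % k >= k - c)       # second op: last c "accumulation" columns of each window
    return same_block, summary

local, strided = strided_masks(N=64, k=8)
block, summary = fixed_masks(N=64, k=8, c=2)
# Chaining two such masked attentions gives every position a path to every
# earlier position, while each single mask has only on the order of N*k
# nonzero entries rather than N^2.
```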
The paper argues that, between these two systems, the strided version should work best when the data has some inherent periodicity, but I don’t know that I particularly follow that intuition. I have some sense that the important distinction here is that in the strided case you have many points of accumulation, each with not much context, whereas in the fixed case you have very few accumulation points, each with a larger window, but I don’t know exactly what performance differences I’d expect these mechanical differences to predict.
This whole project of reducing access to global information seems initially a little counterintuitive, since the whole point of a Transformer design, in some sense, was its ability to gain global context in a single layer, as opposed to a convnet needing multiple layers to build up a receptive field, or an RNN needing to maintain state throughout the sequence. However, I think this paper makes the most sense as a way of interpolating the space between something like a CNN and a full attention design, for the sake of efficiency. With a CNN, you have a fixed kernel, so as your sequence gets longer, you need to add more and more layers in order for any given point to be able to incorporate into its representation information from the complete other side of the sequence. With an RNN, as your sequence gets longer, you pay the cost of needing to backpropagate state farther. By contrast, even though the Sparse Transformer seems to be giving up its signal advantage, it’s really just trading one constant number of steps to achieve global visibility (1) for another (2 in this paper, though conceptually it could be more), but still in a way that’s constant with respect to the length of the data. And, in exchange for this trade, they get very sparse, very masked operations, where many of the multiplications involved in these big query calculations can be ignored, making for faster computation.
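As a back-of-the-envelope illustration of that trade (an assumption-laden sketch, not a benchmark): if the stride k is chosen to be around sqrt(N), as the paper suggests, each of the two sparse operations attends to roughly k keys per query, so the total number of attended query-key pairs scales as roughly 2*N*sqrt(N) instead of N^2.

```python
# Rough comparison of attended pairs, assuming stride k ~ sqrt(N).
N = 16384
k = int(N ** 0.5)            # 128
dense_pairs = N * N          # ~2.7e8
sparse_pairs = 2 * N * k     # ~4.2e6 across the two chained operations
print(f"dense: {dense_pairs:.2e}, sparse: {sparse_pairs:.2e}, "
      f"ratio ~{dense_pairs / sparse_pairs:.0f}x")
```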
On the datasets tried, the Sparse Transformer increased speed, and, I think, in all cases also increased performance - not by much; the performance gain by itself isn’t that dramatic, but in a context where you’d expect, if anything, worse performance as a result of restricting the model structure, it’s encouraging and interesting that performance either stays about the same or possibly improves.