Summary by CodyWild
The idea of the Switch Transformer is to have more parameters available for a network to use, but to only use a small subset of those parameters for each example that's run through the network. This is achieved through a routing scheme, whereby a weighting layer is applied to each token and produces a set of logits/softmax weights over the set of possible experts. The token is then sent to the expert that was given the highest weight. The network is implemented such that different experts can actually live on different devices.
https://i.imgur.com/HEB7cJw.png
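To make the routing concrete, here is a minimal PyTorch-style sketch of top-1 ("switch") routing. The names (`SwitchRouter`, `switch_ffn`), shapes, and expert definitions are illustrative rather than taken from the paper's implementation, and it leaves out the expert capacity limits and cross-device dispatch a real system needs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwitchRouter(nn.Module):
    def __init__(self, d_model: int, num_experts: int):
        super().__init__()
        # A single weighting layer produces one logit per expert for each token.
        self.w_gate = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, tokens: torch.Tensor):
        # tokens: (num_tokens, d_model)
        logits = self.w_gate(tokens)                   # (num_tokens, num_experts)
        probs = F.softmax(logits, dim=-1)              # softmax weights over experts
        expert_prob, expert_index = probs.max(dim=-1)  # top-1: keep only the best expert
        return expert_index, expert_prob

def switch_ffn(tokens, experts, router):
    """Each token is processed by only its chosen expert's feed-forward block."""
    expert_index, expert_prob = router(tokens)
    out = torch.zeros_like(tokens)
    for e, expert in enumerate(experts):
        mask = expert_index == e
        if mask.any():
            # Scale by the router probability so the routing decision stays differentiable.
            out[mask] = expert_prob[mask].unsqueeze(-1) * expert(tokens[mask])
    return out

# Usage: experts are ordinary feed-forward blocks; only one of them runs per token.
d_model, d_ff, num_experts = 512, 2048, 8
experts = nn.ModuleList(
    nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
    for _ in range(num_experts)
)
router = SwitchRouter(d_model, num_experts)
tokens = torch.randn(16, d_model)
out = switch_ffn(tokens, experts, router)  # (16, d_model)
```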
This architecture is inspired by previous Mixture of Experts work, which applied a similar routing scheme but sent each token through its top k experts rather than just a single one. This was ostensibly done to increase stability and performance, but the authors of this paper argue that using a single expert per token is actually preferable on both of those fronts.
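For contrast, a classic MoE-style gate picks each token's top k experts and mixes their outputs with renormalized gate weights. This is a rough sketch of that pattern under the same assumptions as above, not the exact formulation from any particular prior paper.

```python
import torch
import torch.nn.functional as F

def topk_gate(router_logits: torch.Tensor, k: int = 2):
    # router_logits: (num_tokens, num_experts)
    probs = F.softmax(router_logits, dim=-1)
    topk_probs, topk_indices = probs.topk(k, dim=-1)                 # k experts per token
    topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)   # renormalize mixture weights
    # Each token's output is a weighted sum over its k experts' outputs.
    return topk_indices, topk_probs
```

With k = 1 this collapses to the top-1 routing above; the paper's argument is that the extra compute and communication of k > 1 isn't needed.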
There are a lot of experiments in this paper, and I'd recommend taking a look at them in detail if you're interested, but, at a high level, they found evidence that, compared to dense models with a comparable number of parameters, they were able to get comparable or better performance at a lower FLOP count. It also meant they were able to scale up to a trillion-parameter model without unreasonable computation requirements.
Some interesting considerations relevant to this approach:
- To keep training speed up, roughly the same number of tokens needs to be sent to each expert; the authors add an auxiliary loss term that incentivizes the router to divide tokens roughly uniformly across the experts (see the load-balancing sketch after this list)
- There was some numerical instability in the routing computation when run in low-precision (bfloat16) arithmetic, so the authors selectively compute the router's softmax in float32 and cast back afterwards, rather than raising the precision of the rest of the network (see the selective-precision sketch after this list)
- To regularize such a huge network during fine-tuning, the authors apply dropout at a higher rate within the expert layers than in the rest of the network ("expert dropout")
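To make the load-balancing term concrete, here is a minimal sketch in the same PyTorch style: it scales the dot product between the fraction of tokens dispatched to each expert and the mean router probability assigned to each expert, which equals the coefficient alpha under perfectly uniform routing and grows as tokens pile onto a few experts. The helper name and default coefficient are illustrative.

```python
import torch

def load_balancing_loss(router_probs: torch.Tensor,
                        expert_index: torch.Tensor,
                        num_experts: int,
                        alpha: float = 0.01) -> torch.Tensor:
    # router_probs: (num_tokens, num_experts) softmax output of the router
    # expert_index: (num_tokens,) index of the expert each token was dispatched to
    num_tokens = router_probs.shape[0]
    # f_i: fraction of tokens actually sent to expert i (a hard, non-differentiable count)
    dispatch_fraction = (
        torch.bincount(expert_index, minlength=num_experts).float() / num_tokens
    )
    # P_i: mean router probability assigned to expert i (this is where gradients flow)
    mean_router_prob = router_probs.float().mean(dim=0)
    # Scaled dot product: alpha when routing is perfectly uniform, larger when a few
    # experts hog the tokens, nudging the router toward an even split.
    return alpha * num_experts * torch.dot(dispatch_fraction, mean_router_prob)
```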
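And a small sketch of the selective-precision point, assuming a mixed-precision setup where activations are in bfloat16: only the router's softmax runs in float32, and the result is cast back down so the experts and the rest of the network stay in low precision.

```python
import torch
import torch.nn.functional as F

def route_with_selective_precision(router_logits: torch.Tensor):
    # router_logits: (num_tokens, num_experts), e.g. bfloat16 under mixed precision
    probs = F.softmax(router_logits.float(), dim=-1)  # do the unstable softmax in float32
    expert_prob, expert_index = probs.max(dim=-1)     # top-1 expert per token
    # Cast back down so dispatch and the experts themselves run in the low-precision dtype.
    return expert_index, expert_prob.to(router_logits.dtype)
```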