I would like to see a comparison of inference-time compute between a regular transformer and this. I'm assuming tokens/s is lower, since you need to compute the weights of the model for each token before the actual attention calculations for that sequence position.
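To make the overhead I'm imagining concrete, here's a toy sketch (purely my own assumption of a hypernetwork-style setup where a small generator emits the main projection's weights per token; none of these names or shapes come from the paper):

```python
import time
import torch

d_model, seq_len = 256, 256
x = torch.randn(seq_len, d_model)

# Baseline: a fixed projection applied to every token.
static_proj = torch.nn.Linear(d_model, d_model)

# Hypothetical per-token weights: a small net emits a full (d x d) matrix per token.
weight_gen = torch.nn.Linear(d_model, d_model * d_model)

def static_forward(x):
    return static_proj(x)

def generated_forward(x):
    w = weight_gen(x).view(-1, d_model, d_model)    # (seq, d, d), one matrix per token
    return torch.bmm(x.unsqueeze(1), w).squeeze(1)  # apply each token's own weights

for name, fn in [("static", static_forward), ("generated", generated_forward)]:
    t0 = time.perf_counter()
    for _ in range(20):
        fn(x)
    print(f"{name}: {(time.perf_counter() - t0) / 20 * 1e3:.2f} ms per pass")
```

The generated path does an extra large matmul per token before anything attention-related happens, which is the kind of thing I'd expect to show up in a tokens/s comparison.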
Isn't that figure 5 in the paper? It's for training, not inference, but presumably if training is faster then inference would be too, because they don't increase the dimension of the text tokens when scaling up, which reduces the compute needed for attention. That potentially limits how well the text-token attention can keep track of things, though, since it has less space for passing information along.
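On the attention-cost point, a back-of-envelope sketch (my own standard FLOP estimates, not numbers from the paper): per layer the Q/K/V/output projections cost roughly 8·n·d² and the score/value matmuls roughly 4·n²·d, so keeping the text-token dimension d fixed while scaling other parts keeps both terms from growing.

```python
def attention_flops_per_layer(n_tokens: int, d_model: int) -> int:
    """Rough FLOP count for one self-attention layer (multiply-add counted as 2 FLOPs)."""
    projections = 8 * n_tokens * d_model ** 2        # Q, K, V, output projections
    scores_and_values = 4 * n_tokens ** 2 * d_model  # QK^T and softmax(QK^T)V
    return projections + scores_and_values

# Same sequence length, two token dimensions: attention cost grows with d.
for d in (512, 1024):
    print(d, f"{attention_flops_per_layer(4096, d) / 1e9:.1f} GFLOPs")
```

The flip side is exactly the trade-off above: a smaller fixed d also means less residual-stream capacity for the text tokens to carry information forward.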