There is a particularly nice geometric interpretation of attention that I realised recently in a flash of enlightenment, best explained with an interactive Desmos plot (the black dot is draggable):
The above assumes the columns of K are normalised, but bear with me. K and V together form a vector database. V holds the payloads, each row containing a vector of data. K describes the positions of these points in space, on the surface of a hypersphere. The query vector describes the query into the database: its direction gives the point in space being queried, and its magnitude sets the radius of the query (a larger magnitude gives a tighter radius). The result is the weighted average of the vectors in V, each weighted by the distance of its key from the query direction, scaled by the query radius, with a smooth Gaussian falloff. A recent paper from Nvidia that I recommend reports a significant speedup from normalising vectors to a hypersphere: https://arxiv.org/abs/2410.01131v1
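To see where the Gaussian falloff comes from: for a unit-length key k, the squared straight-line distance between k and the query direction q/|q| is 2 - 2*cos(theta), so exp(q.k) is proportional to exp(-|q| * distance^2 / 2). Here is a minimal sketch in plain Python that checks this numerically. The keys, the query, and the function names are made up for illustration, and the 1/sqrt(d) scaling used in practice is left out.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(logits):
    # Subtract the max for numerical stability; the result is unchanged.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(q, keys):
    """Softmax over the dot products q.k (no 1/sqrt(d) scaling)."""
    return softmax([dot(q, k) for k in keys])

def gaussian_weights(q, keys):
    """Normalised Gaussian in the distance from q's direction to each key.

    The width shrinks as |q| grows: logit = -|q| * distance^2 / 2.
    """
    r = math.sqrt(dot(q, q))
    q_hat = [x / r for x in q]
    logits = []
    for k in keys:
        dist_sq = sum((a - b) ** 2 for a, b in zip(q_hat, k))
        logits.append(-0.5 * r * dist_sq)
    return softmax(logits)

# Hypothetical toy data: five unit-length keys on a circle, one 2D query.
angles = [0.0, 0.7, 1.9, 3.1, 4.5]
keys = [(math.cos(a), math.sin(a)) for a in angles]
q = (1.2, 0.9)

w_attn = attention_weights(q, keys)
w_gauss = gaussian_weights(q, keys)
max_diff = max(abs(a - b) for a, b in zip(w_attn, w_gauss))
print(max_diff)  # the two weightings agree up to floating-point rounding
```

This only holds because the keys have unit length. With unnormalised keys the |k|^2 term no longer cancels in the softmax, and the weights pick up a per-key bias.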
Attention is a product of three matrices, s(QK)V, where s is softmax. Each matrix has as many rows (Q and V) or columns (K) as there are tokens in your context. The plot looks at the processing of a single row of Q (predicting a single token from the previous ones), called q. q is a 2-element vector, visualised as the draggable dot (imagine a line from the origin to the dot). The K matrix is shown as green dots; each previous token in the context window is a separate dot. The distance of a blue dot from its corresponding green dot represents how much information from that token gets mixed into the output of the query. The green dots lie on a hypersphere, here a 1D manifold in 2D space (a circle). In a real network it would be more like, e.g., a 127D manifold in 128D space, but the analogy works there as well. You can see how the query gathers information stored on the surface of the manifold by selecting a region of space and its volume through q's direction and magnitude respectively.
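As a rough sketch of what the plot computes, here is one row of s(QK)V in plain Python: a single 2D query q against a handful of keys on the unit circle. The four tokens, their angles, and the one-hot payloads are hypothetical, chosen so that the output vector is just the list of mixing weights. The 1/sqrt(d) scaling is omitted, as in the formula above.

```python
import math

def attend(q, keys, values):
    """One row of softmax(q K) V: returns (weights, weighted average of values)."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(values[0])
    out = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return weights, out

# Hypothetical context of four tokens: keys spaced evenly on the unit circle.
angles = [0.0, 0.5 * math.pi, math.pi, 1.5 * math.pi]
keys = [(math.cos(a), math.sin(a)) for a in angles]

# One-hot payloads, so the output directly shows how much of each token is mixed in.
values = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]

# Fixed query direction (closest to key 0), varying magnitude.
direction = (math.cos(0.3), math.sin(0.3))
results = {}
for magnitude in (0.5, 2.0, 8.0):
    q = tuple(magnitude * x for x in direction)
    weights, out = attend(q, keys, values)
    results[magnitude] = out
    print(magnitude, [round(w, 3) for w in out])
```

Dragging the dot away from the origin in the plot corresponds to raising `magnitude` here: the weight on the nearest key grows and the others shrink, while a short q spreads the weights more evenly across all keys.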
Yeah, I believe this intuition was first introduced by the Neural Turing Machine line of work and later simplified in the AIAYN ("Attention Is All You Need") paper (NTM maintains "external memory", a.k.a. weight_keys and weight_values here).
Disclaimer: this is from memory, so it could be entirely wrong.
Desmos plot: https://www.desmos.com/calculator/3rtqsyapxo