Hacker News
Transformers Represent Belief State Geometry in Their Residual Stream (lesswrong.com)
85 points by supriyo-biswas 5 months ago | 19 comments



My summary: Anything which drives predictions of "the next token" must create a model of "how the world works" AND a model for "what state is the world in right now". The authors train a transformer and demonstrate that it creates a structure that represents both of these.


Crucially, it will tend to find the simplest such representation that still solves the problem. This is why, ultimately, the model is only sufficient to solve the problems it was trained to solve.

Simplest in terms of optimization, by the way.


What do you mean by simplest in terms of optimization?

I get that it finds solutions that are easy for SGD or the Adam optimizer to find.

But why would such a solution be less simple than others?


(I could be wrong here. Please correct me if that's the case.)

I think the comment you're replying to means exactly what you're saying, which is that it will find solutions which are "easy" to find for the optimizer, and therefore solutions which are simple to achieve through the convergence of some optimizer.


I wonder if this could be applied to financial markets to make short-term options trading profitable.


Interesting line of work, but as (too?) often on LessWrong, there is no literature review, or even a related-work section, for going back to seminal research or digging deeper into the topic. It might just be me, but I find it really frustrating when I can't hop between references to skim a field.

Does anyone have cool links or references to share about related work?


The lack of a review can be worked around by googling, though it's exacerbated by the general lack of publishing in this domain, likely because the work isn't done in a classic academic department and there's no motive to publish for citation/impact-factor reasons.

Also, I find it refreshing that it's a very direct and practical approach to presenting results, which is missing from the usual scientific formula these days.


This is a blog post, not a paper. Of course it does not have literature review etc.

Now I don't work in the field of transformers, but generally when I see a piece of original work being presented as a blog post, the most reasonable conclusion is that the work is not sufficiently interesting to be a paper.

Good scientific blogs are the ones that summarize papers after they have been published, or at least written up and posted in full technical detail on a preprint server.


Super cool!

Maybe I missed it, but it didn’t seem to me like the authors presented any reasoning about what makes this capability particular to transformers… especially on the relatively small HMMs, it’s not clear to me that we wouldn’t see the same behavior with some other model performing prediction, but maybe I missed that.

Obviously it makes sense to focus on transformers given their ubiquity in SotA research, but it would be interesting to see how “special” this property is given that there is an implicit suggestion that it will continue to be true on larger problems which currently are primarily only solved with transformers.

On the other hand, we have things like selective state space models (most notably, Mamba), which of course very explicitly represent system state within the architecture of the model itself and are promising alternatives to transformers.

Anyways, very cool post!


> Maybe I missed it, but it didn’t seem to me like the authors presented any reasoning about what makes this capability particular to transformers… especially on the relatively small HMMs, it’s not clear to me that we wouldn’t see the same behavior with some other model performing prediction, but maybe I missed that.

I don't believe they are claiming that?

There is no reason to think that would be the case (transformers are, in general, grossly overrated in terms of magic pixie dust). The fundamental theory here implies that any predictor achieving high performance is going to have to do this, and in the comments people point out examples elsewhere, such as RNNs, where researchers have read off the beliefs from the hidden state and found the same thing: https://www.lesswrong.com/posts/gTZ2SxesbHckJ3CkF/transforme...


"A final theoretical note about Computational Mechanics and the theory presented here: because Computational Mechanics is not contingent on the specifics of transformer architectures and is a well-developed first-principles framework, we can apply this framework to any optimal predictor, not just transformers."


Theoretically, HMMs are "models of the world" and transformers are approximations of HMMs in the "forward" algorithm.
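For readers unfamiliar with the term, the normalized forward algorithm is Bayesian filtering: after each token, the distribution over hidden states is pushed through that token's transition matrix and renormalized. A minimal sketch follows; the labeled-matrix convention (`T[x][i][j]` is the probability of emitting token `x` and moving from state `i` to state `j`) and the example process (one state emits 0, one emits 1, one emits either at random, visited in an assumed cycle Z, then O, then R) are illustrative assumptions, not taken from the article.

```python
# Minimal sketch of the normalized forward algorithm (Bayesian filtering) for an HMM.
# Convention (an assumption): T[x][i][j] = P(emit token x and move to state j | state i).

def forward_filter(T, prior, tokens):
    """Return the belief state after each prefix of `tokens`, plus P(tokens)."""
    belief = list(prior)
    beliefs = [belief]
    seq_prob = 1.0
    n = len(belief)
    for x in tokens:
        # Unnormalized update: u_j = sum_i belief_i * T[x][i][j]
        unnorm = [sum(belief[i] * T[x][i][j] for i in range(n)) for j in range(n)]
        total = sum(unnorm)  # P(x | tokens seen so far)
        if total == 0:
            raise ValueError(f"token {x!r} has probability zero under this model")
        belief = [u / total for u in unnorm]
        beliefs.append(belief)
        seq_prob *= total
    return beliefs, seq_prob

# Hypothetical three-state example, states ordered [Z, O, R]:
# Z emits 0 and moves to O; O emits 1 and moves to R; R emits 0 or 1 (p = 0.5) and moves to Z.
T = {
    0: [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.5, 0.0, 0.0]],
    1: [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.5, 0.0, 0.0]],
}
beliefs, p = forward_filter(T, [1 / 3, 1 / 3, 1 / 3], [0, 1])
# After "0" the belief is [1/3, 2/3, 0]; after "0 1" it collapses onto R, and P("0 1") = 1/3.
```

Starting from the uniform prior, two tokens are enough here to pin down the hidden state exactly, which is the kind of synchronization a good next-token predictor has to track implicitly.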

Something seems sus about how the linear projection ended up in exactly the same shape as their prediction, and about how the projection seems to keep that shape throughout training. Typically, projections look like they "spin around" as they move from a random point cloud to the separated shapes, but I have not done experiments on transformers, and it's unclear what they mean by projection.


Yes, the projection is possibly responsible for it looking like a simplex/triangle, since it's a probability distribution over 3 states.
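To make that concrete: any probability distribution over three states is a point in a 2-simplex, and a linear map sends that simplex to a triangle in the plane. A minimal sketch, assuming an equilateral layout of the corners chosen purely for plotting; under this kind of map, the structure worth looking at is where the belief points land inside the triangle, not the triangle itself.

```python
import math

# Corners of an equilateral triangle (an arbitrary layout choice for plotting).
CORNERS = [(0.0, 0.0), (1.0, 0.0), (0.5, math.sqrt(3) / 2)]

def simplex_to_plane(p):
    """Map a probability vector over 3 states to a 2D point (barycentric coordinates)."""
    if len(p) != 3 or any(q < 0 for q in p) or abs(sum(p) - 1) > 1e-9:
        raise ValueError("expected a probability distribution over 3 states")
    x = sum(q * cx for q, (cx, _) in zip(p, CORNERS))
    y = sum(q * cy for q, (_, cy) in zip(p, CORNERS))
    return (x, y)
```

A fully certain belief such as `[1, 0, 0]` lands on a corner, and the uniform belief lands on the centroid; every other distribution falls inside, so any linear projection of three-state beliefs is bounded by an (affinely transformed) triangle.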

Another individual seems to have asked the same question in the comment section of that article, and they wrote another article with the author after a lot of back and forth:

https://www.lesswrong.com/posts/mBw7nc4ipdyeeEpWs/why-would-...


Thank you, that's a perfect follow-up piece, makes the whole thing much clearer.


So training a transformer is actually training a Bayesian model of beliefs about some probabilities? Makes sense.

If we could figure out how to train a ‘fetch this fact/concept/next few tokens’ from memory or from a SAT solver instead of recalling from a lossy Bayes net…


This article describes an experiment that uses a very simple "world" with three states. Two of the states emit "0" and "1", respectively, and the third is a little trickier: it emits either "0" or "1" at random. The transformer attempts to predict the next token, and they show that it also attempts to guess the current (hidden) state. So far, so good.
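The data-generating side of this setup can be sketched as a sampler that records both the tokens and the hidden state that emitted each one. The comment does not specify the transition structure, so the cycle below (Z, then O, then R, then back to Z) and the state names are assumptions for illustration, not the article's code.

```python
import random

# Hypothetical transition structure: Z -> O -> R -> Z (the order is an assumption).
NEXT = {"Z": "O", "O": "R", "R": "Z"}

def emit(state, rng):
    """Z always emits 0, O always emits 1, R emits 0 or 1 with equal probability."""
    if state == "Z":
        return 0
    if state == "O":
        return 1
    return rng.randint(0, 1)

def sample(n, seed=0):
    """Return (tokens, hidden_states); hidden_states[t] is the state that emitted tokens[t]."""
    rng = random.Random(seed)
    state = rng.choice(list(NEXT))  # start in a uniformly random state
    tokens, states = [], []
    for _ in range(n):
        states.append(state)
        tokens.append(emit(state, rng))
        state = NEXT[state]
    return tokens, states
```

Only `tokens` would be used for next-token training; `states` is the kind of ground truth a probe's output would later be compared against.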

A natural interpretation of this experiment is that there is a "timeline" identical to the stream of tokens - that is, the time coordinate is given by the current offset into the stream. Each token is evidence of the world's current state at the time corresponding to that offset.

But there's no particular reason to call that axis "time." We could just as easily treat it as an x-coordinate, and there's no time at all. The next token is just the one to the right.

Real stories often don't have such a simple mapping. The offset into the text stream is a different dimension than the story's timeline. A "world model" worth its salt would have to attempt to put the various events described within the text stream into some kind of partial order, using various clues about how the events are described. (For a simple experiment, the events could be timestamped to make it easier.)

It would be interesting to know how transformers represent the "current time" and how that relates to other relevant times in a story. I wonder if there's any research about that yet?

Also, could they learn more sophisticated representations from tricky stories involving time travel?


Nice article (paper?).

I've been aware of J. P. Crutchfield's work on statistical complexity for a while now, but never thought of connecting it to deep learning, even though his framework is generic enough to apply to any predictor.

The question, though, which is mentioned as a next step, is how you would go about modeling the data-generating process of "the real world", or of the training datasets of current models, even older ones like GPT-2.

My guess is that the meta-process's state-space size is going to grow exponentially, making this unusable.

The good news is that computational mechanics is a treasure trove, so there's room for a lot of papers exploring this subject as applied to deep learning in particular, and I'm excited about its intersection with the recent work we've seen on mechanistic interpretability.


In this case, the curse of dimensionality is working in our favor, as the volume of an n-dimensional simplex also grows exponentially, so it is perfectly suited to hold the exponential state space of the mixed-state presentation.

Updating from one mixed-state to another just involves updating for each pure state in the mixture and then mixing the resulting states together, so everything is neatly linear.
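The linearity can be checked directly. A minimal sketch, using a hypothetical labeled transition matrix for one observed token: updating the mixture in one step gives the same unnormalized result as updating each pure state and mixing with the same weights. Strictly, the update is linear up to a final renormalization, which divides by a single scalar (the probability of the observed token).

```python
def update_unnormalized(T_x, belief):
    """Push a belief vector through the labeled transition matrix for one observed token."""
    n = len(belief)
    return [sum(belief[i] * T_x[i][j] for i in range(n)) for j in range(n)]

def normalize(v):
    total = sum(v)
    return [x / total for x in v]

# Hypothetical labeled transition matrix for a single token (values chosen for illustration).
T_x = [[0.0, 0.5, 0.0],
       [0.2, 0.0, 0.1],
       [0.0, 0.0, 0.3]]

weights = [0.2, 0.5, 0.3]  # a mixed state: weights on the 3 pure states
pure = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]

# Route 1: update the mixture directly.
direct = update_unnormalized(T_x, weights)
# Route 2: update each pure state, then mix the results with the same weights.
mixed = [sum(w * update_unnormalized(T_x, p)[j] for w, p in zip(weights, pure))
         for j in range(3)]

new_belief = normalize(direct)  # the only nonlinear step is this scalar rescaling
```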

One way to represent the training data would be to have one state per token position that emits the corresponding token and advances to the next position, or to a randomly sampled other document if there's no next position. That matches the way LLMs are actually trained, but of course typical transformers have a much smaller residual stream dimension than the number of training tokens, so it needs to conflate some states, and synchronizing with the training data is also not what we want models to do, otherwise we would be using infini-gram https://arxiv.org/abs/2401.17377 to regurgitate exact matches.
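The "one state per token position" process from the comment can be sketched on a toy corpus. This assumes a uniform prior over positions and a uniformly random next document (both hypothetical choices). Filtering on a context leaves probability only on positions where that context occurs verbatim (possibly spanning a document boundary), so the prediction reduces to counting continuations of exact matches. Unlike infini-gram, which falls back to the longest suffix of the context that does occur, this sketch just returns an empty distribution for unseen contexts.

```python
from collections import defaultdict

# Hypothetical toy corpus; each (doc, position) pair is one hidden state.
DOCS = [["the", "cat", "sat"], ["the", "dog", "sat"]]

def successors(docs, d, k):
    """Next-state distribution: advance one position, or jump to a random document start."""
    if k + 1 < len(docs[d]):
        return {(d, k + 1): 1.0}
    return {(e, 0): 1.0 / len(docs) for e in range(len(docs))}

def next_token_distribution(docs, context):
    """Filter a belief over token positions on `context`, then predict the next token."""
    # Prior: uniform over every token position (an assumption made for this sketch).
    states = [(d, k) for d in range(len(docs)) for k in range(len(docs[d]))]
    belief = {s: 1.0 / len(states) for s in states}
    for x in context:
        new = defaultdict(float)
        for (d, k), mass in belief.items():
            if docs[d][k] == x:  # deterministic emission: only matching positions survive
                for s, p in successors(docs, d, k).items():
                    new[s] += mass * p
        total = sum(new.values())
        if total == 0:
            return {}  # the context never occurs in the corpus (no back-off here)
        belief = {s: m / total for s, m in new.items()}
    # Each surviving state emits its own token next.
    prediction = defaultdict(float)
    for (d, k), mass in belief.items():
        prediction[docs[d][k]] += mass
    return dict(prediction)
```

For example, the context `["the"]` yields an even split between `"cat"` and `"dog"`, while `["cat", "sat"]` predicts `"the"` because the process jumps to a new document start.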


Thanks for this, it provides an intriguing approach to thinking about transformers (or predictors in general)!

For extracting the fractal from the residual stream, did I understand it correctly as follows: You repeatedly sample the transformer, each time recording the actual internal state of the HMM and the (higher-dimensional) residual stream. Then you perform a linear regression to obtain a projection matrix from residual stream vector to HMM state vector.
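The procedure as understood in this comment (sample, record residual vectors alongside targets, fit a linear map by least squares) can be sketched in plain Python. This is not the post's code: the "residual stream" here is synthetic (a hypothetical random 3-to-8 linear embedding of belief vectors plus small Gaussian noise), the targets are belief vectors, and a tiny ridge term is added only for numerical stability. Because the data are linear by construction, the probe should recover the beliefs closely; the open question is what such a fit does and does not prove.

```python
import random

def fit_linear_probe(X, Y, ridge=1e-6):
    """Least-squares map from rows of X (residual vectors) to rows of Y (targets).

    Appends an intercept column, then solves (X^T X + ridge*I) W = X^T Y by
    Gauss-Jordan elimination. W has len(X[0]) + 1 rows and len(Y[0]) columns.
    """
    Xa = [row + [1.0] for row in X]
    d, k = len(Xa[0]), len(Y[0])
    # Augmented normal-equation matrix: d columns of X^T X, then k columns of X^T Y.
    M = [[sum(r[i] * r[j] for r in Xa) + (ridge if i == j else 0.0) for j in range(d)]
         + [sum(r[i] * y[c] for r, y in zip(Xa, Y)) for c in range(k)]
         for i in range(d)]
    for col in range(d):
        pivot = max(range(col, d), key=lambda r: abs(M[r][col]))  # partial pivoting
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(d):
            if r != col:
                f = M[r][col] / M[col][col]
                M[r] = [a - f * b for a, b in zip(M[r], M[col])]
    return [[M[i][d + c] / M[i][i] for c in range(k)] for i in range(d)]

def apply_probe(W, x):
    xa = x + [1.0]
    return [sum(xa[i] * W[i][c] for i in range(len(xa))) for c in range(len(W[0]))]

# Synthetic data: hypothetical embedding A maps 3-state beliefs into an 8-dim "residual".
rng = random.Random(0)
A = [[rng.gauss(0, 1) for _ in range(8)] for _ in range(3)]
beliefs, resid = [], []
for _ in range(500):
    w = [rng.random() for _ in range(3)]
    b = [v / sum(w) for v in w]
    r = [sum(b[s] * A[s][j] for s in range(3)) + rng.gauss(0, 0.01) for j in range(8)]
    beliefs.append(b)
    resid.append(r)

W = fit_linear_probe(resid, beliefs)
err = max(abs(p - t) for x, b in zip(resid, beliefs) for p, t in zip(apply_probe(W, x), b))
```

A small `err` here only shows that the targets are linearly decodable from this synthetic data; it says nothing about how a real model organizes its own computation.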

If so, then doesn't that risk "finding" something that isn't necessarily there? While I think/agree that the structure of the mixed state representation is obviously represented in the transformer in this case, in general I don't think that, strictly speaking, finding a particular kind of structure when projecting transformer "state" into known world "state" is proof that the transformer models the world states and its beliefs about the world states in that same way. Think "correlation is not causation". Maybe this is splitting hairs (because, in effect, what does it matter how exactly the transformer "works" when we can "see" the expected mixed state structure inside it), but I am slightly concerned that we introduce our knowledge of the world through the linear regression.

Like, consider a world with two indistinguishable states (among others), and a predictor that (noisily) models those two with just one equivalent state. Wouldn't the linear regression/projection of predictor states into world states risk "discovering" the two world states in the predictor, which don't actually exist there in isolation at all?
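The scenario can be made concrete with a toy construction. Assume (hypothetically) that world states A and B are observationally indistinguishable and that the true belief always splits their combined probability in a fixed ratio `R`, while the predictor stores only the merged mass. An exact linear map from the predictor's state to the full (A, B, C) belief then exists, so a probe fit against world states would "recover" A and B even though the representation has no separate A or B feature.

```python
import random

R = 0.3  # hypothetical fixed share of the merged mass that belongs to world state A

def world_belief(m):
    """World-level belief over (A, B, C) when A and B always split mass m as R : 1 - R."""
    return [R * m, (1 - R) * m, 1 - m]

def predictor_state(m):
    """A predictor that tracks only the merged A-or-B mass (plus C): no separate A/B feature."""
    return [m, 1 - m]

# Linear "probe" from predictor state to world belief; exact by construction.
PROBE = [[R, 1 - R, 0.0],  # weights applied to the merged-mass feature
         [0.0, 0.0, 1.0]]  # weights applied to the C feature

def apply(probe, x):
    return [sum(x[i] * probe[i][c] for i in range(len(x))) for c in range(len(probe[0]))]

rng = random.Random(0)
masses = [rng.random() for _ in range(1000)]
max_err = max(abs(p - t) for m in masses
              for p, t in zip(apply(PROBE, predictor_state(m)), world_belief(m)))
```

The fit is perfect only because of the fixed-ratio assumption; if the A/B split varied with history, a probe on the merged representation could not track it.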

Again, I'm not doubting the conclusions/explanation of how, in the article, that transformer models that world. I am only hypothesizing that, for more complex examples with more "messy" worlds, looking for the best projection into the known world states is dangerous: It presupposes that the world states form a true subspace of the residual stream states (or equivalent).

Would be happy to be convinced that there is something deeper that I'm missing here. :)



