It is a very bad idea to handle the KV cache in JAX naively like that. Jitted JAX functions require static shapes. You're creating dynamic shapes there, so every new cache shape triggers another trace and compile, causing a ton of recompilation.
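For anyone wondering what the static-shape version looks like, here is a minimal sketch, not the article's code. It assumes a single attention head and hypothetical sizes `MAX_LEN` and `D`; the helper names `write_kv` and `attend` are made up for illustration. The cache is preallocated at its maximum length, each new token is written in place at index `pos`, and a mask hides the slots that have not been written yet, so the shapes never change between steps.

```python
import jax
import jax.numpy as jnp

MAX_LEN, D = 8, 4  # hypothetical maximum sequence length and head dimension

@jax.jit
def write_kv(cache, new_kv, pos):
    # cache: [MAX_LEN, D], new_kv: [1, D]. The output has the same shape as
    # the input cache, so the jitted function is reused at every step.
    return jax.lax.dynamic_update_slice_in_dim(cache, new_kv, pos, axis=0)

@jax.jit
def attend(q, k_cache, v_cache, pos):
    scores = k_cache @ q / (D ** 0.5)        # [MAX_LEN]
    valid = jnp.arange(MAX_LEN) <= pos       # slots written so far
    weights = jax.nn.softmax(jnp.where(valid, scores, -jnp.inf))
    return weights @ v_cache                 # [D]

k_cache = jnp.zeros((MAX_LEN, D))
v_cache = jnp.zeros((MAX_LEN, D))
key = jax.random.PRNGKey(0)
for t in range(MAX_LEN):
    key, sub = jax.random.split(key)
    kv = jax.random.normal(sub, (1, D))      # stand-in for real projections
    k_cache = write_kv(k_cache, kv, t)
    v_cache = write_kv(v_cache, kv, t)
    out = attend(kv[0], k_cache, v_cache, t)
```

The trade-off is that every step does attention over `MAX_LEN` slots even when only a few are filled, which is the price of keeping shapes fixed.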
I used this to see if something is repeatedly compiled. I.e. I have the code that runs in a loop and you immediately see if something is compiled only once, or every time. (and it produces a lot of output) I'm not saying this is the best way to do it though, it just worked for me.
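The comment above does not say which tool it refers to, so the following is only one possible way to check, under my own assumptions. JAX has a `jax_log_compiles` config option that logs each compilation, and a Python side effect inside a jitted function runs only while the function is being traced, so a plain counter shows how often that happens. The function names and sizes here are hypothetical.

```python
import jax
import jax.numpy as jnp

# Log a message every time JAX compiles something (can be noisy).
jax.config.update("jax_log_compiles", True)

traces = {"static": 0, "growing": 0}

@jax.jit
def static_step(buf, x, pos):
    traces["static"] += 1   # Python side effect: runs only while tracing
    return buf.at[pos].set(x)

@jax.jit
def growing_step(buf, x):
    traces["growing"] += 1  # runs again for every new input shape
    return jnp.concatenate([buf, x[None]], axis=0)

fixed = jnp.zeros((8, 4))   # preallocated, shape never changes
grown = jnp.zeros((0, 4))   # grows by one row per step
for t in range(8):
    x = jnp.ones((4,))
    fixed = static_step(fixed, x, t)
    grown = growing_step(grown, x)

print(traces)  # {'static': 1, 'growing': 8}
```

Both buffers end up with the same contents, but the growing one is traced once per step because its input shape is different every time.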
Just don't use jit in generation and it would be fine. Of course there is some performance penalty but in my experience jit is oversold and the difference is something like ~10-30%.
Also in any case to get optimized code you need flash attention and many other tricks.
People have long forgotten that mobile browsers handle wide content by zooming. If you are making a website but don't bother optimizing it for mobile, leave off the viewport <meta> element.
It's not just the width of the column - there are annotations on certain lines (that appear on a right "margin") that don't show up on mobile. I think that makes it not an easy fix, but to your larger point, this is not very mobile friendly. It looks quite good on a desktop though.
There is some research in accelerating Reinforcement Learning by implementing the simulator on the GPU using Jax. Really neat. I'm curious if this could be done with Mojo, too?
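To give a rough idea of why this works well in JAX, here is a toy sketch, not taken from any of that research. It assumes a made-up point-mass environment with a hypothetical timestep `DT`; the step function is written for a single environment, and `jax.vmap` plus `jax.jit` turn it into one compiled call that advances many environments at once on whatever device JAX is using.

```python
import jax
import jax.numpy as jnp

DT = 0.1  # hypothetical simulation timestep

def step(state, action):
    # state: [position, velocity] of one point mass; action: scalar force
    pos, vel = state
    vel = vel + DT * action
    pos = pos + DT * vel
    return jnp.stack([pos, vel])

# Vectorize over the leading axis, then compile the batched step once.
batched_step = jax.jit(jax.vmap(step))

num_envs = 1024
states = jnp.zeros((num_envs, 2))
actions = jnp.ones((num_envs,))
states = batched_step(states, actions)
```

A real simulator would have far more state and contact handling, but the pattern of writing the single-environment step and batching it with `vmap` is the same.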
"focuses on the soul of pure functional programming which makes it more cool"
This is tangential to this post's main point, but if you're trying for mass adoption this can go badly. Case in point: a hardware company I backed decided to write their code in Haskell. Why? "Because it's cool." And now the people who are trying to modify or work with it have to deal with Haskell vs. a general purpose language like C++, idk...
edit: I also realize most of this code is Python, but yeah
> deal with Haskell vs. a general purpose language like C++
What's the actual problem? Company decided to use Haskell (which is also a general-purpose language) then hired people who don't know it?
If so, hire a bunch of Pythonistas to work on a Rails project and you'll have a similar kind of struggle (and it won't mean that Python or Ruby is somehow bad; it'll be an almost entirely non-technical issue).
If you know Haskell and don't know C++ then C++ will be harder to read. Haskell is definitely less widely used than C++, but that doesn't make it more complex.
To the poster who wrote: "Hey Saurabh, will you be willing to teach me this on a call? I'm willing to pay for it (im not rich, so, dont expect much please). I will be having a lot of questions, mostly related to core concepts of transformers and jax in general."
This is the wrong way to ask for help.
Instead, consider offering your help and time, apprenticing and learning along the way. Can't code that well? Write test cases and clean up. Or help with blog writing, etc. You certainly have some valuable skill you could trade up.
I mean, I’m no Saurabh, but that didn’t seem too unreasonable to me? In fact, I’ll put my money where my mouth is and offer half an hour for free just to spite you