I am really looking forward to JAX taking over PyTorch/CUDA over the next few years. The whole PTX kerfuffle with the DeepSeek team shows the value of investing in lower-level approaches to squeeze the most out of your hardware.
I like JAX but I'm not sure how an ML framework debate like "JAX vs PyTorch" is relevant to DeepSeek/PTX. The JAX API is at a similar level of abstraction to PyTorch [0]. Both are Python libraries and sit a few layers of abstraction above PTX/CUDA and their TPU equivalents.
[0] Although PyTorch arguably encompasses two levels: both a pure functional library like the JAX API and a "neural network" framework on top of it. JAX doesn't have the latter and leaves that to separate libraries like Flax.
The interesting thing about this comment is that JAX is actually, in general, even higher-level than PyTorch. Since everything is compiled, you just express a logical program and let the compiler (XLA) worry about the rest.
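For what it's worth, a minimal sketch of what that looks like (the function here is just illustrative): you write plain array math and jax.jit hands the whole thing to XLA to fuse and schedule.

```python
import jax
import jax.numpy as jnp

# Illustrative example: express the "logical" computation as plain
# array math; XLA decides how to fuse and schedule the kernels.
@jax.jit
def normalize(x):
    return (x - x.mean()) / (x.std() + 1e-6)

x = jnp.arange(12.0).reshape(3, 4)
print(normalize(x))  # compiled on first call, cached afterwards
```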
Are you suggesting that XLA would be where this "lower level" approach would reside since it can do more automatic optimization?
This has been my bible for performance work internally at Google. Kind of surprised they released it publicly, but I guess they removed all the Gemini-specific details.
The short answer is that tracing is way, way easier to implement in a predictable and reliably performant way. This especially matters for distributed computation and automatic differentiation, two areas where JAX shines.
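As a rough sketch of what tracing buys you (toy function, purely illustrative): JAX runs your Python function once with abstract values and records only the array operations it performs, and that same trace feeds both autodiff and compilation.

```python
import jax
import jax.numpy as jnp

def f(x, w):
    return jnp.tanh(x @ w).sum()

# jax.make_jaxpr traces f with abstract values, recording only the
# array ops it performs -- no Python AST inspection involved.
print(jax.make_jaxpr(f)(jnp.ones((2, 3)), jnp.ones((3, 4))))

# The same trace is what jit and grad transform:
grad_f = jax.jit(jax.grad(f))
print(grad_f(jnp.ones((2, 3)), jnp.ones((3, 4))))
```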
AST parsing via reflection means your ML compiler needs to re-implement all of Python, which is not a small language. This is a lot of work and hard to do well with abstractions that are not designed for those use cases. (I believe Julia's whole-language autodiff systems struggle for essentially the same reason.)
I literally am a paid ML compiler engineer and I have no idea what this means. You understand that reflection, a la looking in a mirror, is about being able to identify a type's type at runtime? It has nothing to do with the AST.
It's quite true. The convergence of all architectures to transformers is well documented by Karpathy. SSMs were once touted as transformer killers, but increasingly look like just optional supplements.
Nothing fancy. I made these with some pretty simple hand-written JavaScript scripts rendering to canvas: lots of fiddly little boxes moving around are simpler to script than to hand-animate. (If I were to do much more of this I might rewrite these in Blender, since it has much nicer authoring tooling and export control.)
Shameless request for help: if anybody has experience with seq2seq on TPU, and you want to do a cool project to deploy a world-class PyTorch image parsing model to TPU (and do this quickly), please contact me immediately for a well-paid and interesting job opportunity at nico [at] mathpix.com.