I recently started using JAX for some ion-optics work in accelerator physics, and I have found it very good. The autodiff stuff is magical for optimisation work, but even just as a compiled NumPy, I have found it easy to get highly performant code. For reference, I previously tried roughly the same thing in “numba” and wasn’t able to get anywhere near the same performance as JAX, even running on the CPU, which I understand is JAX’s weakest backend. By and large I have just written basically idiomatic Python/numpy code — sprinkled a few “vmap”s and “scan”s around, and got great results. I’m very pleased with JAX.
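To give a flavour of the pattern (a toy sketch, not my actual code; the drift step, shapes, and names are invented for illustration):

    import jax
    import jax.numpy as jnp

    # "scan" carries one particle through a sequence of elements,
    # and "vmap" maps that over a whole bundle of particles.
    def step(state, length):
        x, xp = state                 # transverse position and angle
        x = x + length * xp           # a simple drift, for illustration
        return (x, xp), x             # new carry, plus recorded position

    def track_one(particle, lengths):
        _, positions = jax.lax.scan(step, particle, lengths)
        return positions

    lengths = jnp.array([0.5, 1.0, 0.5])        # element lengths (m)
    particles = (jnp.linspace(-1e-3, 1e-3, 8),  # initial positions
                 jnp.zeros(8))                  # initial angles

    track_all = jax.jit(jax.vmap(track_one, in_axes=((0, 0), None)))
    trajectories = track_all(particles, lengths)  # shape (8, 3)

The nice part is that the per-particle code stays scalar and readable, and the batching and compilation are layered on from outside.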
Unfortunately it’s for my work, so it’s not currently public. One chunk of it which I would love to open source sometime is basically just TRANSPORT in modern Python. Despite being a million years old, second-order matrix-style beam optics is actually still a really useful starting point for a lot of stuff. And with it all differentiable, one can do way more kinds of interesting optimisations.
What’s interesting is that the old Fortran source code for TRANSPORT is still out there if one digs, and most of it is taken up by thousands of lines of code that (1) implement a bunch of matrix multiplications that are now just one call to “jnp.einsum()”, and (2) implement a tedious by-hand differentiation of those same calculations with respect to a few parameters, which is now superseded by a single call to “jax.grad()”. It’s amazing how modern tools can make hard stuff truly trivial.
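To make that concrete, here’s a rough sketch in the standard TRANSPORT notation (first-order matrix R and second-order tensor T; the values below are random placeholders):

    import jax
    import jax.numpy as jnp

    # Second-order map: x'_i = sum_j R_ij x_j + sum_jk T_ijk x_j x_k
    def transfer(x, R, T):
        return jnp.einsum('ij,j->i', R, x) + jnp.einsum('ijk,j,k->i', T, x, x)

    k1, k2 = jax.random.split(jax.random.PRNGKey(0))
    R = jnp.eye(6) + 0.01 * jax.random.normal(k1, (6, 6))   # placeholder values
    T = 1e-3 * jax.random.normal(k2, (6, 6, 6))             # placeholder values
    x0 = jnp.array([1e-3, 0.0, 0.0, 0.0, 0.0, 1e-4])

    # The derivatives TRANSPORT computed by hand now come for free:
    jac = jax.jacobian(transfer)(x0, R, T)           # d(x_out)/d(x_in)
    spot = lambda s: transfer(x0, s * R, T)[0] ** 2  # toy figure of merit
    dspot = jax.grad(spot)(1.0)                      # sensitivity to a scale knob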
If you have any interest in single-particle tracking for ion optics, I'd encourage you to check out https://github.com/KATRIN-Experiment/Kassiopeia. It can do both exact and adiabatic tracking of particles (and will solve for B- and E-fields in complex electrode geometries). However, as far as I know it only handles static electric fields for now.
Cool, thanks for the link. It's always interesting to see what's out there (there's so much interesting code around!).
My quick look at that makes me think it's more like a modern SIMION than a modern TRANSPORT.
I'm more interested in working at the level of accelerator beamlines (ensembles of magnets with high-energy beams), and in how to reduce the "impedance mismatch" between that process and the physical design optimization of individual magnets.
Very interesting that it's coming from Google. I did my master's in tokamak simulation, so my first question is about performance. Python is very rarely used in this space, purely for performance reasons. Even though Python can call out to BLAS or whatever, it's usually still worth it to code in Fortran or C, or maybe Julia.
I am doing quite a bit of work with JAX (the Python library used here) in a high-performance numerical computing context.
On GPU/TPU it is not going to reach 100% hardware utilization, but it will get close enough (far above vanilla Python performance) while being significantly more productive than the alternatives.
That makes it a sweet spot for research (where you will want to tweak things as you go) and extremely complex codes (where you already need to put your full focus on the correctness of the code). I highly recommend it to domain experts who need performance for their research project.
> Even though Python can call out to BLAS or whatever, it's still usually worth it to code in Fortran or C or maybe Julia.
I write in a mix of C++ and Python, and have also dabbled with tokamaks, and I think this is a (common) misunderstanding.
Fundamentally, you are optimizing the speed of project progress. If you don't need your results in real time (e.g. because you are building a simulator that is not in a control loop), you are often better off taking the easy language and ignoring its compute performance. The compute performance of a language is a fixed multiplier, and with numerical code you might see a slowdown of up to 10x.
But having readable code, which is easy to manipulate and change to test ideas, speeds up project progress by such a large factor that it is hard for other languages to keep up. The reason Python is omnipresent in machine learning is not for lack of trying on the part of other languages. Python is just very good at letting you keep up with a fast-moving field.
The metric being optimized is not just performance, but also the ability to build reasonably performant workflows with arbitrary differentiable (i.e., ML) inputs and outputs.
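As a toy illustration of that point (everything below is invented for the example): gradients flow end to end through a physics step and a small learned control, so the whole pipeline can be trained or optimized together.

    import jax
    import jax.numpy as jnp

    def physics_step(state, control):
        return state + 0.1 * control - 0.01 * state ** 3   # made-up dynamics

    def learned_control(params, state):
        return jnp.tanh(params['w'] * state + params['b'])  # tiny "policy"

    def loss(params, state0, target):
        state = state0
        for _ in range(10):                                 # unrolled by tracing
            state = physics_step(state, learned_control(params, state))
        return (state - target) ** 2

    params = {'w': jnp.array(0.5), 'b': jnp.array(0.0)}
    grads = jax.grad(loss)(params, jnp.array(1.0), jnp.array(0.0))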
There's a video I saw a while back where they had a model that could predict instability in the plasma before it happened, allowing operators to shut down the machine before the out-of-control plasma damages something. [0]
I just noticed this podcast episode on Deep RL for fusion reactors was recently published, if anyone likes this stuff. I have not listened yet, but this podcast in general is great.
Cool project. I would love to explore simulation projects like this but often don't know where to begin. It's partly because the domains are so foreign to me.