
JAX is a numerics library combined with autograd for machine learning. you string together operations in python in a functional style, which are then passed to XLA, google's optimizing compiler, which can target cpus, gpus and tpus and generate optimized machine code for those architectures.
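to make that concrete, here is a rough sketch of that workflow (the function and numbers are made up for illustration): a pure function written with jax.numpy, differentiated with jax.grad and compiled by XLA via jax.jit.

    import jax
    import jax.numpy as jnp

    # a pure, functional-style composition of array operations
    def loss(w, x, y):
        pred = jnp.tanh(x @ w)          # matrix multiply + elementwise nonlinearity
        return jnp.mean((pred - y) ** 2)

    grad_loss = jax.grad(loss)           # autograd: d(loss)/d(w)
    fast_grad = jax.jit(grad_loss)       # trace once, compile with XLA for cpu/gpu/tpu

    w = jnp.zeros((3,))
    x = jnp.ones((5, 3))
    y = jnp.zeros((5,))
    print(fast_grad(w, x, y))            # runs the compiled kernel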

numba lets you write a subset of python which is then compiled with llvm, producing something very similar to what you would get if you applied a bunch of regexes to that subset of python to convert it to C (with special handling for numpy arrays, since they have special importance in these domains).
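as a rough sketch of the kind of loop-heavy code numba is aimed at (the function and sizes are invented for illustration):

    import numpy as np
    from numba import njit

    @njit  # compiled to machine code with llvm on first call
    def pairwise_dist(xs):
        n = xs.shape[0]
        out = np.empty((n, n))
        for i in range(n):          # plain for loops and arithmetic,
            for j in range(n):      # the kind of code you'd otherwise write in C
                d = 0.0
                for k in range(xs.shape[1]):
                    d += (xs[i, k] - xs[j, k]) ** 2
                out[i, j] = d ** 0.5
        return out

    print(pairwise_dist(np.random.rand(100, 3)))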

numba is specifically targeted at things like core numerical algorithms that are typically coded in C and fortran and consist of little more than for loops and basic arithmetic. JAX is targeted more at high level machine learning applications where the end user is stringing together higher level numerical algorithms.

i suspect that JAX would be a bad fit for custom computer vision or numerical algorithms used outside of neural network work.



JAX is actually lower level than deep learning frameworks (despite including some specialized constructs), which makes it an almost drop-in replacement for numpy with the ability to jit your python code.
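For example (a minimal sketch, the function is made up), the same numpy-style code can run on plain numpy arrays, on jax arrays, or jit-compiled:

    import numpy as np
    import jax.numpy as jnp
    from jax import jit

    def normalize(x):
        return (x - x.mean()) / x.std()

    print(normalize(np.arange(10.0)))        # plain numpy
    print(normalize(jnp.arange(10.0)))       # same code, jax arrays
    print(jit(normalize)(jnp.arange(10.0)))  # same code, compiled by XLA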

I am currently doing some tests introducing JAX in a large numerical code base (that was previously using C++ extensions); we are not using autograd or any deep-learning-specific functionality. Having seen actual numbers, I can tell you that JAX on CPU is competitive with C++ but produces more readable code, with the added benefit of also running on GPU. However, it does introduce some constraints (array sizes cannot be too dynamic), so if you are not also planning on targeting GPU, I would probably focus on numba.
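To illustrate the constraint (a minimal sketch, not from the code base in question): jax.jit traces and compiles once per distinct input shape, so code whose array sizes change on every call recompiles constantly.

    import jax
    import jax.numpy as jnp

    @jax.jit
    def smooth(x):
        return (x[1:] + x[:-1]) / 2.0

    smooth(jnp.ones(1000))   # traced and compiled for shape (1000,)
    smooth(jnp.ones(1000))   # reuses the cached compilation
    smooth(jnp.ones(1001))   # new shape -> traced and compiled again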


i actually poked around a bit as a contributor a few years ago (before i had to start a real job) and remember it being a thin layer on top of XLA, along with a few other things. interesting to learn that it is growing into something that people are using as a fully fledged numerical computing library.

also a little bit surprising to see how immature and fragmented the python gpu numerical computing ecosystem is. everybody bags on matlab, but it has been automatically shipping relevant operations over to available gpus for years.


There are also tricks to get around the constraint on dynamic array shapes, like padding your arrays up to a few common sizes: everything between 6 and 10 becomes 10, everything from 11 to 20 pads to 20, etc.
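A rough sketch of that bucketing trick (the bucket sizes and helper are invented; whether zero-padding is safe depends on the computation):

    import jax
    import jax.numpy as jnp

    BUCKETS = (10, 20, 50, 100)  # the only lengths jit will ever see

    def pad_to_bucket(x):
        # hypothetical helper: round the length up to the next bucket and zero-pad
        n = x.shape[0]
        target = next(b for b in BUCKETS if b >= n)
        return jnp.pad(x, (0, target - n)), n

    @jax.jit
    def total(x):
        return x.sum()

    padded, n = pad_to_bucket(jnp.arange(13.0))  # length 13 padded to 20
    print(total(padded))    # compiled once per bucket size, not once per length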

Jax is a great general purpose numerical computing library.


Jax offers much lower level control, it's almost bare metal, and can be used for all sorts of things besides Deep Learning. I am currently using it to implement a better `scipy.optimize` library.
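That library isn't shown here, but as a minimal sketch of why jax fits this use case, here is a toy gradient-descent minimizer built from jax.grad and jax.jit (the names, step size and test function are made up):

    import jax
    import jax.numpy as jnp

    def minimize_gd(f, x0, lr=0.1, steps=200):
        """Toy gradient-descent minimizer; a real library would add line search, etc."""
        grad_f = jax.grad(f)

        @jax.jit
        def step(x):
            return x - lr * grad_f(x)

        x = x0
        for _ in range(steps):
            x = step(x)
        return x

    # usage example: minimize the rosenbrock function
    rosen = lambda p: (1 - p[0]) ** 2 + 100 * (p[1] - p[0] ** 2) ** 2
    print(minimize_gd(rosen, jnp.array([0.0, 0.0]), lr=1e-3, steps=5000))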


I would love to learn more about that library!



