Hacker News

Looks great, but outside Google I do not personally know anyone who uses Jax, and I work in this space.



Not at Google but currently using Jax to leverage TPUs, because AWS's GPU pricing is eye-gougingly expensive. For the lower-end A10 GPUs, the price-per-GPU for a 4-GPU machine is 1.5x the price for a 1-GPU machine, and the price-per-GPU for an 8-GPU machine is 2x the price of a 1-GPU machine! If you want an A100 or H100, the only option is renting an 8-GPU instance. With properly TPU-optimised code you get something like a 30-50% cost saving on GCP TPUs compared to AWS (and I say that as someone who otherwise doesn't like Google as a company and would prefer to avoid GCP if there weren't such a significant cost advantage).


I use it for GPU accelerated signal processing. It really delivers on the promise of "Numpy but for GPU" better than all competing libraries out there.
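A minimal sketch of what "Numpy but for GPU" looks like in practice (a toy example, not the commenter's actual pipeline): `jax.numpy` mirrors the NumPy API, and `jit` compiles the whole function with XLA for whatever accelerator is available.

```python
import jax.numpy as jnp
from jax import jit

@jit  # compile the whole function with XLA; runs on GPU/TPU if present
def bandpass_energy(x):
    # toy signal-processing step: FFT, zero the low-frequency bins, inverse FFT
    X = jnp.fft.fft(x)
    X = X.at[:4].set(0.0)  # functional update (JAX arrays are immutable)
    return jnp.sum(jnp.abs(jnp.fft.ifft(X)) ** 2)

x = jnp.linspace(0.0, 1.0, 256)
e = bandpass_energy(x)
```

The same code runs unchanged on CPU, GPU, or TPU; only the backend differs.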


We've built our startup from scratch on JAX, selling text-to-image model finetuning, and it's given us a consistent edge not only in terms of pure performance but also in terms of "dollars per unit of work".


Is that gain from TPU usage or something else?


Mostly from the tight JAX-TPU integration yeah


Isn't JAX the most widely used framework in the GenAI space? Most companies there use it -- Cohere, Anthropic, CharacterAI, xAI, Midjourney etc.


Most of the GenAI players use both PyTorch and JAX, depending on the hardware they're running on. Character, Anthropic, Midjourney, etc. are dual shops (they use both). xAI only uses JAX afaik.


just guessing that tech leadership at all of those traces back to Google somehow


Jax trends on papers with code:

https://paperswithcode.com/trends


Was gonna ask "What's that MindSpore thing that seems to be taking the research world by storm?" but I Googled and it's apparently Huawei's open-source AI framework. 1% to 7% market share in 2 years is nothing to sneeze at - those are growth rates similar to Chrome or Facebook in their heyday.

It's telling that Huawei-backed MindSpore can go from 1% to 7% in 2 years, while Google-backed Jax is stuck at 2-3%. Contrary to popular narrative in the Western world, Chinese dominance is alive and well.


>It's telling that Huawei-backed MindSpore can go from 1% to 7% in 2 years, while Google-backed Jax is stuck at 2-3%. Contrary to popular narrative in the Western world, Chinese dominance is alive and well.

MindSpore has an advantage there because of its integrated support for Huawei's Ascend 910B, the only Chinese GPU that comes close to matching the A100. Given the US banned export of A100 and H100s to China, this creates artificial demand for the Ascend 910B chips and the MindSpore framework that utilises them.


No, MindSpore is rising because of the chip embargo.

No one is going to use stuff whose supply could be cut off one day.

One signal of this: Nvidia listed Huawei as a competitor in 4 out of 5 categories in its earnings report.


Its meteoric rise started well before the chip embargo. I've looked into it: it liberally borrows ideas from other frameworks, both PyTorch and Jax, and adds some of its own. You lose some conceptual purity, but it makes up for that in practical usability, assuming it works as it says on the tin, which it may or may not.

PyTorch also has support for Ascend as far as I can tell (https://github.com/Ascend/pytorch), so that support does not necessarily explain MindSpore's relative success. Why MindSpore is rising so rapidly is not entirely clear to me. It could be something as simple as preferring a domestic alternative that is adequate to the task and has better documentation in Chinese. It could be the cost of compute. It could be both. Nowadays, however, I do agree that the various embargoes would help it (as well as Huawei) a great deal.

As a side note, I wish Huawei could export its silicon to the West. I bet that'd result in dramatically cheaper compute.


This data might just be unreliable. It had a weird spike in Dec 2021 that looks unusual compared to all the other frameworks.


China publishes a looooootttttt of papers. A lot of it is careerist crap.

To be fair, a lot of US papers are also crap, but Chinese crap research is on another level. There's a reason a lot of top US researchers are Chinese - there's brain drain going on.


When I looked into a random sampling of these uses, my impression was that it was a common kind of project in China to take a common paper (or another repo) and implement it in Mindspore. That accounted for the vast majority of the implementations.


Note that most of Jax’s minuscule share is Google.


I’m in academia and I use Jax because it’s the closest thing to translating maths directly into code.


Same, Jax is extremely popular with the applied math/modeling crowd.


I use it all the time, and there's also a few classes at my uni that use Jax. It's really great for experimentation and research, you can do a lot of things in Jax you just can't in, say, PyTorch.


Like what?


Anytime you want to make something GPU-accelerated that doesn't fit as standard operations on tensors. For example, I often write RL environments in Jax, which is something you can't do in PyTorch. There are also things you can do in PyTorch but that would be far more difficult, for example an efficient implementation of MCTS.

I also used Jax a lot for differential equations, not even sure how I would do that with PyTorch.

Basically, Torch is a lot more like a specialization of Numpy for neural networks, while Jax feels a lot more like being able to write CUDA as Python, and also getting the Jacobians (jacs! jax!) and Jacobian-vector products (jvp) for free (of everything; you can even differentiate through your optimizer with respect to its hyperparameters, which is crazy).

In the end, when you're doing fundamental research and coming up with something new, I think Jax is just better. If all I had to do was implementation, I'd be a happy PyTorch user.
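A toy illustration of that "Jacobians for free" point: differentiating a rolled-out loop (think one explicit-Euler ODE solve, or an RL environment rollout) with respect to both the initial state and the step size, which here plays the role of a "hyperparameter". The function names are made up for the sketch.

```python
import jax
import jax.numpy as jnp

# One step of a toy dynamics: explicit Euler on ds/dt = sin(s)
def step(state, dt):
    return state + dt * jnp.sin(state)

def rollout(s0, dt, n=10):
    # a plain Python loop; JAX traces and differentiates through it
    for _ in range(n):
        s0 = step(s0, dt)
    return s0

s0 = jnp.array(0.5)
# gradient of the final state w.r.t. the initial state...
dfinal_dinit = jax.grad(rollout)(s0, 0.1)
# ...and w.r.t. the step size dt, straight through the whole loop
dfinal_ddt = jax.grad(rollout, argnums=1)(s0, 0.1)
```

The same pattern scales up: wrap `rollout` in `jax.jit` or `jax.vmap` and the derivatives still come along for free.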


A small addendum: the only people I know who use Jax are people who work at Google, or people who had a big GCP grant and needed to use TPUs as a result.


That's cool -- but wouldn't it be more constructive to discuss "the ideas" in this package anyway?

For instance, it would be interesting to discern whether the design of PyTorch (and its modules) precludes or admits the same sort of visualization tooling. If you have expertise in PyTorch, perhaps you could help answer this sort of question?

JAX's Pytrees are like "immutable structs, with array leaves" -- does PyTorch have a similar concept?
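For reference, a pytree in JAX is just any nested Python container with array leaves, manipulated functionally rather than in place; a minimal sketch:

```python
import jax
import jax.numpy as jnp

# any nesting of dicts/lists/tuples with array leaves is a pytree
params = {"layer1": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)},
          "scale": jnp.array(3.0)}

# tree_map applies a function leaf-wise and returns a NEW tree (no mutation)
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)

# flatten into a list of array leaves plus a treedef describing the structure
leaves, treedef = jax.tree_util.tree_flatten(params)
```

The "immutable structs" feel comes from `tree_map` always returning a fresh tree while leaving the original untouched.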


> does PyTorch have a similar concept

of course https://github.com/pytorch/pytorch/blob/main/torch/utils/_py...


Idk if you need that immutability actually. You could probably reconstruct enough to do this kind of viz from the autograd graph, or capture the graph and intermediates in the forward pass using hooks. My hunch is it should be doable.
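A minimal sketch of the hook-based approach in PyTorch (toy model; the module names "0", "1", "2" are just what `nn.Sequential` assigns automatically):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
captured = {}

def make_hook(name):
    # record each module's output tensor during the forward pass
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook

for name, module in model.named_modules():
    if name:  # skip the root Sequential itself
        module.register_forward_hook(make_hook(name))

out = model(torch.randn(1, 4))
```

After one forward pass, `captured` holds every intermediate activation, which is the raw material a visualization tool would need.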


If JAX had affine_grid() and grid_sample(), I'd be using it instead of PyTorch for my current project.
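There's no built-in equivalent, but for the bilinear-sampling half, `jax.scipy.ndimage.map_coordinates` with `order=1` can get part of the way. A rough sketch (it takes pixel coordinates rather than PyTorch's normalized [-1, 1] grid convention, and there's no `affine_grid` counterpart here):

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

def grid_sample_2d(image, rows, cols):
    # bilinear sampling of a 2D image at (row, col) pixel coordinates
    coords = jnp.stack([rows.ravel(), cols.ravel()])
    out = map_coordinates(image, coords, order=1, mode="constant")
    return out.reshape(rows.shape)

img = jnp.arange(16.0).reshape(4, 4)
rows, cols = jnp.meshgrid(jnp.array([0.5, 1.5]),
                          jnp.array([0.5, 1.5]), indexing="ij")
sampled = grid_sample_2d(img, rows, cols)  # each value averages 4 neighbors
```

Whether this is fast enough for a real `grid_sample` workload is another question; it's a workaround, not a drop-in replacement.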


It would be great if we could have intelligent tools for building neural networks in PyTorch.


Would a comprehensive object-construction platform with schema support and the ability to hook up to a compiler (i.e. turn object data into code, for instance) be a useful tool in this domain?

ex: https://www.youtube.com/watch?v=fPnD6I9w84c

I am the developer, happy to answer questions.



