JAX should be mentioned [1]. It's also from Google and is getting popular these days. Not PyTorch-popular, but the progress and investment seem promising.

[1] https://github.com/google/jax


When I first read about JAX I thought it would kill Pytorch, but I'm not sure I can get on with an immutable language for tensor operations in deep learning.

If I have an array `x` and want to set index 0 to 10, I cannot do:

  x[0] = 10

I instead have to do:

  y = x.at[0].set(10)

I'm sure I could get used to it, but it really puts me off.
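
For anyone unfamiliar with the idiom, here is a minimal sketch (assuming `jax.numpy` is imported as `jnp`); the update returns a new array and leaves the original untouched:

  import jax.numpy as jnp

  x = jnp.zeros(3)
  y = x.at[0].set(10)   # functional update: returns a new array
  print(x)              # [0. 0. 0.]   (x is unchanged)
  print(y)              # [10.  0.  0.]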


Agreed that that's a bit ugly, but at least in the ML context you rarely, if ever, need to do this (personally I only do this on the input to models, where we use pure numpy).


I feel the same. There are probably more ergonomic and generalizable ways to do whatever it is you need to do. Treat it as functional programming and let the XLA compiler handle things.


Looks like Java code. It's DOA.


I disagree.


JAX is the Arch Linux of machine learning scientists


The underlying concept of JAX, function transformations, is very powerful and elegant.

PyTorch 2.0 now has a similar underlying mechanism in torch.compile: https://pytorch.org/get-started/pytorch-2.0/
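
To make the transformation idea concrete, here's a minimal sketch composing grad, vmap, and jit on an ordinary Python function (the function and shapes are just illustrative):

  import jax
  import jax.numpy as jnp

  def loss(w, x):
      return jnp.sum((x @ w) ** 2)

  # Each transformation takes a function and returns a new function.
  grad_loss = jax.grad(loss)                        # gradient w.r.t. w
  batched = jax.vmap(grad_loss, in_axes=(None, 0))  # map over a batch of x
  fast = jax.jit(batched)                           # compile with XLA

  w = jnp.ones(3)
  xs = jnp.arange(6.0).reshape(2, 3)
  print(fast(w, xs))                                # per-example gradients, shape (2, 3)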


JAX would be amazing if it had a better debugging story. As it is, if something goes wrong in a transformed function, you lose pretty much all ability to use normal methods of fixing it. It practically just says ‘something went wrong.’


There are now jax.debug.print() and jax.debug.breakpoint(). It's a start.

https://jax.readthedocs.io/en/latest/debugging/print_breakpo...
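
A minimal sketch of jax.debug.print inside a jitted function; unlike a plain print(), which only sees abstract tracers during tracing, it reports the runtime values:

  import jax
  import jax.numpy as jnp

  @jax.jit
  def f(x):
      y = x * 2
      jax.debug.print("y = {}", y)   # prints concrete values at runtime
      return y + 1

  f(jnp.arange(3.0))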


I think JAX is behind PyTorch in production usability, though. It's a tool for cutting-edge architecture research, but it lacks the infrastructure to actually deploy models in a production environment.



