JAX should be mentioned [1]. It's also from Google and is getting popular these days. Not PyTorch-popular, but the progress and investment seem promising.
When I first read about JAX I thought it would kill PyTorch, but I'm not sure I can get on with an immutable language for tensor operations in deep learning.
If I have an array `x` and want to set index 0 to 10, I cannot do:
`x[0] = 10`
I instead have to do:
`y = x.at[0].set(10)`
I'm sure I could get used to it, but it really puts me off.
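To make the difference concrete, here's a minimal sketch (assuming a recent `jax` install) of the functional-update style: the `.at[...]` update returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# In-place assignment raises a TypeError: JAX arrays are immutable.
# x[0] = 10

# The functional equivalent returns a *new* array; x is unchanged.
y = x.at[0].set(10)
print(y)  # first element is now 10
print(x)  # still all zeros
```

Under `jit`, XLA can usually turn this into an actual in-place update anyway, so the copy is more conceptual than real.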
Agreed that that's a bit ugly but at least in the ML context you rarely if ever need to do this (personally I only do this on the input to models, where we use pure numpy).
I feel the same. There are probably more ergonomic and generalizable ways to do whatever it is you need to do. Treat it as functional programming and let the XLA compiler handle things.
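As a small illustration of that functional style (a sketch, with `bump_first` being a made-up example function): write pure functions over arrays and let `jax.jit` hand the whole thing to XLA, which can fuse the update and the arithmetic into one compiled computation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def bump_first(x):
    # A pure function: no mutation, just a new array. XLA compiles
    # and fuses the .at update and the multiply together.
    return x.at[0].set(10) * 2

out = bump_first(jnp.ones(3))
print(out)
```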
JAX would be amazing if it had a better debugging story. As it is, if something goes wrong inside a transformed function, you lose pretty much all of the normal debugging tools; it practically just says ‘something went wrong.’
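One partial mitigation worth knowing (a sketch assuming a recent `jax`): an ordinary `print()` inside a jitted function only fires once, at trace time, with abstract tracer values; `jax.debug.print` is the supported way to see runtime values from inside transformed code.

```python
import jax

@jax.jit
def f(x):
    # print(x) here would show a Tracer, not the value.
    # jax.debug.print runs at execution time, even under jit.
    jax.debug.print("x = {}", x)
    return x * 2

result = f(3.0)
```

There's also `jax.disable_jit()` for temporarily running things eagerly so a regular debugger works, though neither fully closes the gap with PyTorch's step-through experience.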
I think JAX is behind PyTorch in production usability, though. It's a tool for cutting-edge architecture research, but it lacks the infrastructure to actually deploy models in a production environment.
[1] https://github.com/google/jax