Hacker News

Hello from the JAX team!

We'd like to take this opportunity to give a shout out to some of the awesome projects folks are building on top of JAX, e.g.,

* Flax, a neural network library for JAX (https://github.com/google/flax)

* Haiku, a neural network library for JAX inspired by Sonnet (https://github.com/deepmind/dm-haiku)

* RLax, a library for building reinforcement learning agents (https://github.com/deepmind/rlax)

* NumPyro, a probabilistic programming library on top of JAX (https://github.com/pyro-ppl/numpyro)

* JAX-MD, a differentiable molecular dynamics package built on top of JAX (https://github.com/google/jax-md)

I've noticed that many JAX libraries (including those from Google) seem to adopt an object-oriented style closer to Torch/Keras than to the functional style JAX demonstrates in modules like jax.experimental.stax. This is disappointing, since stax is quite clean and these libraries seem to use a lot of hacks to make OO work with JAX. Is there an effort to implement and maintain more full-featured functional libraries in the jax/stax style?

I've been involved with jax/stax/trax/flax. I think the real issue with the stax-like functional form is that it gets unwieldy very quickly when dealing with more complicated models that are natively general graphs, as opposed to simple sequential pipelines that can be trivially mapped to a combinator expression tree. Of course there are many solutions here, but ultimately, if you're building an NN library, you need to build something that ML researchers actually want to use daily, and that tends to look closer to hackable PyTorch-like DSLs than to higher-order functional code, which often looks elegant but tends to hurt readability and rework speed.
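For readers who haven't seen the style under discussion, here is a minimal sketch of the stax-like combinator pattern, written from scratch rather than copied from stax itself. The names `Dense`, `Relu`, and `serial` are stand-ins defined here for illustration; the only assumption carried over is the convention that each layer is an `(init_fun, apply_fun)` pair and that parameters are passed around explicitly.

```python
import jax
import jax.numpy as jnp


def Dense(out_dim):
    """Illustrative layer constructor: returns an (init_fun, apply_fun) pair."""
    def init_fun(rng, input_shape):
        in_dim = input_shape[-1]
        W = jax.random.normal(rng, (in_dim, out_dim)) * 0.01
        b = jnp.zeros((out_dim,))
        output_shape = tuple(input_shape[:-1]) + (out_dim,)
        return output_shape, (W, b)

    def apply_fun(params, inputs):
        W, b = params
        return inputs @ W + b

    return init_fun, apply_fun


def Relu():
    """Parameter-free layer: its params are an empty tuple."""
    def init_fun(rng, input_shape):
        return input_shape, ()

    def apply_fun(params, inputs):
        return jnp.maximum(inputs, 0.0)

    return init_fun, apply_fun


def serial(*layers):
    """Combinator: chains layers into one (init_fun, apply_fun) pair."""
    init_funs, apply_funs = zip(*layers)

    def init_fun(rng, input_shape):
        params = []
        for init in init_funs:
            rng, layer_rng = jax.random.split(rng)
            input_shape, layer_params = init(layer_rng, input_shape)
            params.append(layer_params)
        return input_shape, params

    def apply_fun(params, inputs):
        for apply, layer_params in zip(apply_funs, params):
            inputs = apply(layer_params, inputs)
        return inputs

    return init_fun, apply_fun


# A purely sequential model maps directly onto the combinator tree.
init_fun, apply_fun = serial(Dense(8), Relu(), Dense(2))
out_shape, params = init_fun(jax.random.PRNGKey(0), (4, 3))

x = jnp.ones((4, 3))
y = apply_fun(params, x)


def loss(params, x):
    return jnp.sum(apply_fun(params, x) ** 2)


# Parameters are a plain nested structure, so jax.grad works on them directly.
grads = jax.grad(loss)(params, x)
```

The sequential case stays tidy. In this sketch, though, something like a residual connection can't be written with `serial` alone; it needs extra fan-out and fan-in combinators, which is where the expression tree starts to get harder to read for general-graph models.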

Don't forget Neural Tangents, a high-level library for building and running experiments with infinite-width neural networks: https://github.com/google/neural-tangents

What about Trax (https://github.com/google/trax) and how does it compare with Flax or Haiku?

Interesting that Googlers, who are supposed to use TensorFlow, are now actively developing a new autograd engine and at least three new DL frameworks on top of it. What do you think about this fragmentation?

Good catch, I missed Trax! Trax is a configuration-driven neural network framework focused on sequence model research, as a successor to Tensor2Tensor.

Comparisons are hard in general and I don't have a good answer for you right now, but keep in mind most of these libraries are from researchers openly sharing the codebases they develop for their own work. We see the role of JAX as analogous to NumPy, that is, a common substrate on which folks can build these sorts of tools.
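As a rough illustration of the NumPy analogy, here is a small sketch using only `jax.numpy`, `jax.grad`, and `jax.jit`. The functions `predict` and `loss` and the toy data are made up for this example; the point is only that array code written in NumPy style can be differentiated and compiled, which is the layer the libraries above build on.

```python
import jax
import jax.numpy as jnp


def predict(w, x):
    # Ordinary NumPy-style array code, using jax.numpy in place of numpy.
    return jnp.tanh(x @ w)


def loss(w, x, y):
    # Mean squared error of the toy model.
    return jnp.mean((predict(w, x) - y) ** 2)


# Function transformations compose: differentiate with respect to w, then compile.
grad_loss = jax.jit(jax.grad(loss))

w = jnp.zeros((3,))
x = jnp.ones((5, 3))
y = jnp.ones((5,))

g = grad_loss(w, x, y)  # gradient has the same shape as w
```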
