
JAX: Numpy with Gradients, GPUs and TPUs - one-more-minute
https://github.com/google/jax
======
chrisaycock
For anyone else outside of machine learning who was wondering what all of this
is, here is my best explanation:

The inferencing phase of a neural network attempts to minimize _error_ or
_loss_ as defined by the user. This is done by iteratively applying gradient
descent to the error function. Thus, the error function must have a known
derivative, which can be difficult if the function has loops and conditionals.

Autograd is a software package that produces derivatives of a function
automatically. It creates a computation graph of the user-defined code from
which it can determine the derivative:

[https://en.wikipedia.org/wiki/Automatic_differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)

XLA is a JIT from TensorFlow that compiles common array functions. The JAX
project from this GitHub page brings JIT optimizations to Autograd's automatic
differentiation. That will speed up evaluation of the error function when
inferring a neural network. Neat!
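
To make the gradient-descent bit concrete, here is a minimal sketch using
JAX's `grad` and `jit` (the loss function and data below are invented purely
for illustration):

```
import jax
import jax.numpy as jnp

# Toy squared-error loss for a linear model; purely illustrative.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# jax.grad gives the derivative of the loss w.r.t. w;
# jax.jit compiles the whole thing through XLA.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.ones((8, 3))
y = jnp.zeros(8)
w = jnp.array([1.0, -2.0, 0.5])
for _ in range(100):
    w = w - 0.1 * grad_loss(w, x, y)  # one gradient-descent step
```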

(I would be grateful for any corrections to my above explanation as I am not
an expert in ML.)

~~~
twtw
Another small correction: every instance of "inference" in your comment should
probably be replaced with "training." It's the training phase that involves
running gradient descent of various flavors to optimize the network
parameters.

~~~
wenc
This is an important point.

Inferencing means “to predict” (I’m not sure when this terminology became
popular; a few years ago most of us were just using the word “predict”).

Once trained, a model no longer requires derivatives. It’s more or less a
function evaluation, which can be done on plain CPUs.

~~~
conistonwater
It's from statistical inference, e.g., when the goal is to find the values of
a model's parameters that match the sample. So if the model is y = f(x,
params), inference gives you params, and prediction gives you y for a value of
x that you haven't seen before.

Also, shouldn't the verb be _inferring_?

~~~
wenc
Except in AI terminology[1], it seems that _inference_ doesn't mean outputting
params. It means outputting y.

Yes, I believe the verb is "inferring".

[1] [https://blogs.nvidia.com/blog/2016/08/22/difference-deep-
lea...](https://blogs.nvidia.com/blog/2016/08/22/difference-deep-learning-
training-inference-ai/)

~~~
conistonwater
You're probably right then. I find AI nomenclature to be a bit of a mess.

------
duckworthd
My $0.02.

I've been using JAX for a while now. A paper I'm an author on
([https://arxiv.org/abs/1806.09597](https://arxiv.org/abs/1806.09597), a
follow-up to
[https://news.ycombinator.com/item?id=18633215](https://news.ycombinator.com/item?id=18633215))
resulted in an algorithm that required taking second derivatives on a per-
example basis. This is extremely difficult in TensorFlow, but with JAX it was
a 2-liner. Even better, it's _super_ fast, thanks to XLA's compile-to-GPU and
JAX's auto-batching mechanics.
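
Roughly, that two-liner has the shape of the sketch below (with a stand-in
loss function; the actual code from the paper is different):

```
import jax
import jax.numpy as jnp

# Stand-in per-example loss; the real one from the paper differs.
def loss(w, x):
    return jnp.sum(jnp.tanh(x * w))

# Second derivatives w.r.t. w, computed for each example in the batch:
per_example_hess = jax.vmap(jax.hessian(loss), in_axes=(None, 0))

w = jnp.array([0.1, 0.2])
batch = jnp.ones((32, 2))
print(per_example_hess(w, batch).shape)  # (32, 2, 2)
```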

I highly recommend JAX to power users. It's nowhere near as feature-complete
in a neural network sense as, say, PyTorch, but it is very good at what it
does, and its core developers are second to none in responsiveness.

~~~
albertzeyer
Why is that difficult in TensorFlow? Wouldn’t you just call tf.gradients on
the output of tf.gradients? At least that is what I usually do and it’s very
simple.

~~~
duckworthd
There are two reasons,

1) Take a closer look at tf.hessian()'s implementation. Notice the tf.while()
loop, calculating tf.gradient() for each coordinate of the first gradient?
That doesn't happen in parallel! Plus, you need to know about TensorArrays!

2) Per-example gradients mean an outer tf.while() loop over each example.
Another linear slowdown!

The difficulty I refer to is in wrangling conditionals in TF, and the
trickiness of obtaining partial derivatives (rather than total derivatives)
which JAX makes trivial.

There may be a

~~~
albertzeyer
tf.while code will run iterations in parallel as far as possible (if they do
not depend on each other), up to some configurable degree.

~~~
duckworthd
Yes! Though this answer has some subtlety. tf.while() will run several
iterations in parallel but this is not the same as _batching_ those same
iterations. For that, you'll need to use the experimental parallel_for feature
[0]. Using this should get you to roughly the same speed as JAX.

Tricky!

[0]:
[https://github.com/tensorflow/tensorflow/tree/b3e00739468080...](https://github.com/tensorflow/tensorflow/tree/b3e007394680801113f492fa1f5a9784e8502f19/tensorflow/python/ops/parallel_for)
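
To show what "batching" means on the JAX side (toy loss, not TF's
parallel_for itself):

```
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((x - w) ** 2)

w = jnp.arange(3.0)
batch = jnp.ones((1000, 3))

# Loop: 1000 separate gradient evaluations, even if some run concurrently.
grads_loop = [jax.grad(loss)(w, x) for x in batch]

# Batched: one vectorized gradient evaluation over the whole batch.
grads_vmap = jax.vmap(jax.grad(loss), in_axes=(None, 0))(w, batch)  # (1000, 3)
```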

------
buildbot
I wonder how performance compares to cupy:
[https://cupy.chainer.org](https://cupy.chainer.org)

Seems a little limited in terms of supported operations for autograd compared
to Chainer: [https://chainer.org](https://chainer.org) or
Flux: [http://fluxml.ai](http://fluxml.ai)

Really cool to see though! XLA/TPU support is awesome.

~~~
alexbw
Actually, advanced autodiff is one of its intended points of, er,
differentiation :). The authors wrote the original Autograd package [0],
released in 2014, that led to “autograd” becoming used as a generic term in
PyTorch and other packages. JAX has all of the autodiff operations that
Autograd does, including `grad`, `vjp`, `jvp`, etc. We’re working on expanding
the number of supported NumPy ops, which is limited right now, but it’s early
days. Try it out; we’re really excited to see what you build with it!

0: [https://github.com/hips/autograd](https://github.com/hips/autograd)
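
For anyone who hasn't used those, a quick sketch on a made-up scalar function:

```
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * x ** 2

# Reverse-mode gradient of a scalar function.
print(jax.grad(f)(1.5))

# Forward-mode: Jacobian-vector product (a directional derivative).
y, tangent_out = jax.jvp(f, (1.5,), (1.0,))

# Reverse-mode: vector-Jacobian product.
y, f_vjp = jax.vjp(f, 1.5)
(cotangent_out,) = f_vjp(1.0)
```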

------
snackematician
Very cool! I love autograd; it had tape-based autodiff way before pytorch, and
the way it wraps numpy is much more convenient than tensorflow/pytorch. I've
been wanting GPU support in autograd for years now, so I'm very happy to see
this.

I have some academic software
([https://github.com/popgenmethods/momi2](https://github.com/popgenmethods/momi2))
that uses autograd, was planning to port it to pytorch since it's better
supported/maintained, but now I'll have to consider jax. Though I'm a little
worried about the maturity of the project, seems like the numpy/scipy coverage
is not all the way there yet. Then again, it would be fun to contribute back
to JAX, I did contribute a couple PRs to autograd back in the day so I think I
could jump right into it...

------
whoisnnamdi
This is very cool, and I can see all sorts of use cases where a tool like this
could be valuable.

Definitely will try to keep in mind that a tool for fast differentiation of
arbitrary functions exists out in the world when starting my next project

Thanks for posting!

------
hyperpallium
> JAX can automatically differentiate native Python ... functions

sympy can differentiate functions, but they have to be set up properly. How
can JAX differentiate native functions?

(Or do they mean numerical differentiation, like a finite difference
estimation?)

~~~
cgearhart
Automatic Differentiation is an algorithm to efficiently compute the _value_
of the derivative of a function implemented by some arbitrary code. It does
not use numerical approximation; it combines algorithmic cleverness with a
table of analytic derivatives for elementary functions. Despite reliance on
analytic derivatives, it does not compute the analytic derivative of the
function–it just computes the _value_ of the derivative for some particular
input.

In order to differentiate native functions you have to accept a fairly loose
definition of "derivative" that basically assumes you can ignore
discontinuities in functions. For example, in the function below we can say
that the derivative is piecewise continuous with a value of 0 for all x < 0,
and value 1 for all x >= 0. AD extends this idea to "each output produced by a
function follows a single path, so the corresponding derivative follows the
same path".

```
def f(x):
    if x >= 0:
        return x
    else:
        return 0
```

AD cannot magically determine the derivative of `lambda x: exp(x)` or
similar–it needs a lookup table for elementary functions, for example:
[https://github.com/HIPS/autograd/blob/304552b1ba42e42bce97f0...](https://github.com/HIPS/autograd/blob/304552b1ba42e42bce97f01c6be16e64006c8323/autograd/numpy/numpy_jvps.py#L94)
However, AD does support differentiating through program flow control
including function calls, loops, conditionals, etc. subject to the caveats
above, which is much more difficult to do analytically.
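
A small sketch of this in JAX, using essentially the function above (with 0.0
instead of 0 so the output stays a float that `grad` accepts):

```
import jax

def f(x):
    if x >= 0:
        return x
    else:
        return 0.0

# Outside of jit, grad traces with concrete values, so the ordinary Python
# branch above is fine: each input follows one path and gets that path's
# derivative value.
print(jax.grad(f)(3.0))   # 1.0
print(jax.grad(f)(-3.0))  # 0.0
```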

The idea has now been generalized to all linear operators in functional
programming; see the talk
[https://www.youtube.com/watch?v=MmkNSsGAZhw&feature=youtu.be](https://www.youtube.com/watch?v=MmkNSsGAZhw&feature=youtu.be)
and the associated paper
[https://dl.acm.org/citation.cfm?doid=3243631.3236765](https://dl.acm.org/citation.cfm?doid=3243631.3236765)

~~~
hyperpallium
After checking the wiki link: Does it internally construct a kind of AST for
the symbolic derivative, but only returns particular values, not the
derivative itself? i.e. Since the new function returns derivative values, it
seems it must itself be the derivative...

Looking at your github link, multiplication (product rule) doesn't seem to be
handled there (only `def_linear(anp.multiply)`).

Would an implementation be something really straightforward, like:

    *(f(x), g(x)): f(x)*g'(x) + f'(x)*g(x)

i.e. it internally constructs an AST of the derivative, but only returns
results at specific points, not the AST itself.

(Actually, the wiki example for Forward Accumulation
([https://wikipedia.org/wiki/Automatic_differentiation#Forward...](https://wikipedia.org/wiki/Automatic_differentiation#Forward_accumulation))
does include a product differentiated in this way, so I guess I got it right).

~~~
yorwba
> multiplication (product rule) doesn't seem to be handled there (only
> `def_linear(anp.multiply)`)

 _def_linear_ is what handles the product rule. Other product operations (like
anp.cross, anp.inner etc.) are implemented the same way. It's called "linear"
because products are multi-linear functions, i.e. they are linear in each
individual argument and you can get the derivative with respect to each
parameter by simple substitution. (x+Δx)y - xy = Δxy. Together with the chain-
rule for multi-parameter functions, the classic product rule falls out for
free.
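
As a sketch of what such a rule looks like when written by hand in JAX (purely
illustrative; JAX and Autograd already define this rule for multiply
internally):

```
import jax

@jax.custom_jvp
def my_multiply(x, y):
    return x * y

@my_multiply.defjvp
def my_multiply_jvp(primals, tangents):
    x, y = primals
    dx, dy = tangents
    # Linear in each argument separately: d(xy) = (dx)y + x(dy),
    # which is exactly the classic product rule.
    return my_multiply(x, y), dx * y + x * dy

print(jax.grad(lambda x: my_multiply(x, 3.0))(2.0))  # 3.0
```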

------
lostmsu
What does it do if the function is not differentiable? Most of the commonly
seen functions in real code aren't.

------
p1esk
How does it compare to CuPy (in terms of Numpy compatibility)?

