JAX: Numpy with Gradients, GPUs and TPUs (github.com/google)
132 points by one-more-minute on Dec 8, 2018 | hide | past | favorite | 29 comments

For anyone else outside of machine learning who was wondering what all of this is, here is my best explanation:

The inferencing phase of a neural network attempts to minimize error or loss as defined by the user. This is done by iteratively applying gradient descent to the error function. Thus, the error function must have a known derivative, which can be difficult if the function has loops and conditionals.

Autograd is a software package that produces derivatives of a function automatically. It creates a computation graph of the user-defined code from which it can determine the derivative.
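As an illustration (not from the original comment), here's a minimal sketch using JAX's `grad`, whose API mirrors Autograd's; it assumes `jax` is installed:

```python
import jax

# A function with a conditional: autodiff follows the branch actually taken.
def f(x):
    if x >= 0:
        return x ** 2
    return -x

df = jax.grad(f)   # df evaluates f'(x) at a point; it is not a symbolic formula
print(df(3.0))     # 6.0  (derivative of x**2 at x = 3)
print(df(-2.0))    # -1.0 (derivative of -x)
```

Note that `grad` returns a new Python function, so derivatives of derivatives come from just calling it again.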


XLA is a JIT from TensorFlow that compiles common array functions. The JAX project from this GitHub page brings JIT optimizations to Autograd's automatic differentiation. That will speed up the error function when inferring a neural network. Neat!
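A sketch of how the pieces compose (the toy loss function here is illustrative, not from the comment): `jax.grad` builds the derivative and `jax.jit` compiles it with XLA:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    pred = jnp.dot(x, w)              # toy linear model
    return jnp.mean((pred - y) ** 2)  # mean squared error

# grad builds the derivative w.r.t. w; jit compiles it once with XLA,
# so repeated calls reuse the compiled kernel.
grad_loss = jax.jit(jax.grad(loss))

w = jnp.zeros(3)
x = jnp.ones((4, 3))
y = jnp.ones(4)
g = grad_loss(w, x, y)  # [-2. -2. -2.]
```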

(I would be grateful for any corrections to my above explanation as I am not an expert in ML.)

Another small correction: every instance of "inference" in your comment should probably be replaced with "training." It's the training phase that involves running gradient descent of various flavors to optimize the network parameters.

This is an important point.

Inferencing means “to predict.” (I’m not sure when this terminology became popular; a few years ago most of us were just using the word “predict.”)

Once trained, a model no longer requires derivatives. It’s more or less a function evaluation, which can be done on plain CPUs.

It's from statistical inference, e.g., when the goal is to find the values of a model's parameters that match the sample. So if the model is y = f(x, params), inference gives you params, and prediction gives you y for a value of x that you haven't seen before.

Also, shouldn't the verb be inferring?

More than that, inference usually refers to acts of decision making or evidence evaluation, like testing hypotheses or interpreting confidence/credible intervals.

Except in AI terminology[1], it seems that inference doesn't mean outputting params. It means outputting y.

Yes, I believe the verb is "inferring".

[1] https://blogs.nvidia.com/blog/2016/08/22/difference-deep-lea...

You're probably right then. I find AI nomenclature to be a bit of a mess.

Small correction: XLA is a compiler in the more general sense, not a JIT specifically (it's a domain-specific compiler focused on linear algebra). JAX uses XLA to do JIT compilation.

I think this is more about exploiting XLA to speed up autograd than about deep learning. You would generally use TensorFlow for actual training; I've never encountered a situation where autograd had to be used during training.

My $0.02.

I've been using JAX for a while now. A paper I'm an author on (https://arxiv.org/abs/1806.09597, a follow-up to https://news.ycombinator.com/item?id=18633215) resulted in an algorithm that required taking second-derivatives on a per-example basis. This is extremely difficult in TensorFlow, but with JAX it was a 2-liner. Even better, it's _super_ fast, thanks to XLA's compile-to-GPU and JAX's auto-batching mechanics.
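The paper's actual code isn't reproduced here, but as an illustrative sketch of the kind of two-liner this enables (the toy loss is hypothetical): per-example second derivatives compose `grad`, `grad` again, and `vmap` for auto-batching:

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.tanh(w * x)  # toy per-example scalar loss

# Second derivative w.r.t. w, computed per example: grad-of-grad,
# then vmap batches it over the examples in x.
per_example_d2 = jax.vmap(jax.grad(jax.grad(loss)), in_axes=(None, 0))

xs = jnp.array([0.5, 1.0, 2.0])
d2 = per_example_d2(1.0, xs)  # one second derivative per example, shape (3,)
```

`vmap` is what avoids the explicit per-example loop you'd otherwise write in TensorFlow.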

I highly recommend JAX to power users. It's nowhere near as feature-complete from a neural network sense as, say, PyTorch, but it is very good at what it does, and its core developers are second to none in responsiveness.


About this time last year I had an optimization problem involving linear combinations of Voronoi cell centroids and I basically used random search (I can't remember why differential evolution didn't work).

I don't know much about, say, Keras, but I think I understand how to implement backprop-like behaviors with Autograd. Maybe even sandwich concept-specific models between fully-connected linear+sigmoid layers to give them some oomph.

Why is that difficult in TensorFlow? Wouldn’t you just call tf.gradients on the output of tf.gradients? At least that is what I usually do and it’s very simple.

There are two reasons,

1) Take a closer look at tf.hessian()'s implementation. Notice the tf.while() loop, calculating tf.gradient() for each coordinate of the first gradient? That doesn't happen in parallel! Plus, you need to know about TensorArrays!

2) Per-example gradients means an outside tf.while() loop over each example. Another linear slow down!

The difficulty I refer to is in wrangling conditionals in TF, and the trickiness of obtaining partial derivatives (rather than total derivatives) which JAX makes trivial.

There may be a

tf.while code will run in parallel as far as possible (when iterations do not depend on each other), up to some configurable degree.

Yes! Though this answer has some subtlety. tf.while() will run several iterations in parallel, but this is not the same as _batching_ those same iterations. For that, you'll need to use the experimental parallel_for feature [0]. Using this should get you roughly the same speed as JAX.


[0]: https://github.com/tensorflow/tensorflow/tree/b3e00739468080...

I wonder how performance compares to CuPy: https://cupy.chainer.org

Seems a little limited in terms of supported operations for autograd compared to Chainer (https://chainer.org) or Flux (http://fluxml.ai).

Really cool to see though! XLA/TPU support is awesome.

Actually, advanced autodiff is one of its intended points of, er, differentiation :). The authors wrote the original Autograd package [0], released in 2014, that led to “autograd” becoming used as a generic term in PyTorch and other packages. JAX has all of the autodiff operations that Autograd does, including `grad`, `vjp`, `jvp`, etc. We’re working on expanding the number of supported NumPy ops, which is limited right now, but it’s early days. Try it out; we’re really excited to see what you build with it!
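For the curious, a small hedged example of those three transforms (the function `f` is arbitrary, not from the thread):

```python
import jax
import jax.numpy as jnp

f = lambda x: x * jnp.sin(x)

g = jax.grad(f)(1.0)                      # reverse mode: sin(1) + 1*cos(1)
y, pullback = jax.vjp(f, 1.0)             # vector-Jacobian product (reverse)
y2, tangent = jax.jvp(f, (1.0,), (1.0,))  # Jacobian-vector product (forward)
# For a scalar function, g, pullback(1.0)[0], and tangent all agree.
```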

0: https://github.com/hips/autograd

Very cool! I love autograd, it had tape-based autodiff way before pytorch, and the way it wraps numpy is much more convenient than tensorflow/pytorch. Been wanting GPU support in autograd for years now, so am very happy to see this.

I have some academic software (https://github.com/popgenmethods/momi2) that uses autograd, was planning to port it to pytorch since it's better supported/maintained, but now I'll have to consider jax. Though I'm a little worried about the maturity of the project, seems like the numpy/scipy coverage is not all the way there yet. Then again, it would be fun to contribute back to JAX, I did contribute a couple PRs to autograd back in the day so I think I could jump right into it...

This is very cool, and I can see all sorts of use cases where a tool like this could be valuable.

Definitely will try to keep in mind that a tool for fast differentiation of arbitrary functions exists out in the world when starting my next project

Thanks for posting!

> JAX can automatically differentiate native Python ... functions

sympy can differentiate functions, but they have to be set up properly. How can JAX differentiate native functions?

(Or do they mean numerical differentiation, like a finite difference estimation?)

Automatic Differentiation is an algorithm to efficiently compute the value of the derivative of a function implemented by some arbitrary code. It does not use numerical approximation; it combines algorithmic cleverness with a table of analytic derivatives for elementary functions. Despite relying on analytic derivatives, it does not compute the analytic derivative of the function; it just computes the _value_ of the derivative for some particular input.

In order to differentiate native functions you have to accept a fairly loose definition of "derivative" that basically assumes you can ignore discontinuities in functions. For example, for the function below the derivative is piecewise constant, with value 0 for all x < 0 and value 1 for all x >= 0. AD extends this idea: each output produced by a function follows a single execution path, so the corresponding derivative follows the same path.

```
def f(x):
    if x >= 0:
        return x
    else:
        return 0
```

AD cannot magically determine the derivative of `lambda x: exp(x)` or similar; it needs a lookup table for elementary functions, for example: https://github.com/HIPS/autograd/blob/304552b1ba42e42bce97f0... However, AD does support differentiating through program flow control including function calls, loops, conditionals, etc., subject to the caveats above, which is much more difficult to do analytically.
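To illustrate the control-flow point, a small sketch (assuming `jax` is available; the function is arbitrary) differentiating through an ordinary Python loop:

```python
import jax

def f(x):
    # x**5 via an ordinary Python loop; AD differentiates straight through it
    out = 1.0
    for _ in range(5):
        out = out * x
    return out

print(jax.grad(f)(2.0))  # 80.0, i.e. 5 * 2**4
```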

The idea has now been generalized to all linear operators in functional programming https://www.youtube.com/watch?v=MmkNSsGAZhw&feature=youtu.be and the associated paper https://dl.acm.org/citation.cfm?doid=3243631.3236765

After checking the wiki link: Does it internally construct a kind of AST for the symbolic derivative, but only returns particular values, not the derivative itself? i.e. Since the new function returns derivative values, it seems it must itself be the derivative...

Looking at your github link, multiplication (product rule) doesn't seem to be handled there (only `def_linear(anp.multiply)`).

Would an implementation be something really straightforward, like:

    *(f(x),g(x)): f(x)*g'(x) + f'(x)*g(x)
i.e. it internally constructs an AST of the derivative, but only returns results at specific points, not the AST itself.

(Actually, the wiki eg for Forward Accumulation (https://wikipedia.org/wiki/Automatic_differentiation#Forward...) does include a product differentiated in this way, so I guess I got it right).

> multiplication (product rule) doesn't seem to be handled there (only `def_linear(anp.multiply)`)

def_linear is what handles the product rule. Other product operations (like anp.cross, anp.inner etc.) are implemented the same way. It's called "linear" because products are multi-linear functions, i.e. they are linear in each individual argument, and you can get the derivative with respect to each parameter by simple substitution: (x+Δx)y − xy = (Δx)y. Together with the chain rule for multi-parameter functions, the classic product rule falls out for free.
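A quick illustrative check of that substitution property using JAX (the names here are not from the thread):

```python
import jax
import jax.numpy as jnp

f = lambda x, y: jnp.multiply(x, y)

# multiply is linear in each argument, so each partial derivative
# is just the other argument, "by substitution".
print(jax.grad(f, argnums=0)(3.0, 4.0))  # 4.0
print(jax.grad(f, argnums=1)(3.0, 4.0))  # 3.0
```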

Neither, they mean automatic differentiation [1], which is not symbolic like SymPy or Mathematica, but also is not finite differences.

Often, the interface for doing this is to just extend an existing numerical or array type and overload all the arithmetic operations to keep track of the gradients. Then ordinary code for numerical computations will "just work". That's basically what they've done here, except with a sophisticated compiler.
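A minimal sketch of that overloading approach, using toy "dual numbers" rather than anything JAX actually does internally:

```python
# Forward-mode AD via operator overloading: each value carries its derivative.
class Dual:
    def __init__(self, val, grad=0.0):
        self.val, self.grad = val, grad

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.grad + other.grad)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # product rule: (fg)' = f'g + fg'
        return Dual(self.val * other.val,
                    self.grad * other.val + self.val * other.grad)
    __rmul__ = __mul__

def derivative(f, x):
    # Seed the input's derivative with 1.0 and read off the output's.
    return f(Dual(x, 1.0)).grad

print(derivative(lambda x: 3 * x * x + x, 2.0))  # 13.0 (= 6*2 + 1)
```

Ordinary numeric code then "just works" when handed a `Dual` instead of a float.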

1: https://en.wikipedia.org/wiki/Automatic_differentiation

How's AD different from simply calling the function with a slightly perturbed input and observing how much the output changes? I assume it's more efficient with multivariable functions because the aforementioned method requires one call per parameter?

It is indeed much more efficient. In general, you can evaluate a scalar function and its full gradient for a small constant multiple of the effort required to compute the scalar function alone -- regardless of the number of parameters.
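For instance, with JAX (an illustrative sketch, not from the comment), one reverse-mode pass returns the value and every partial derivative at once:

```python
import jax
import jax.numpy as jnp

f = lambda w: jnp.sum(jnp.sin(w) ** 2)  # scalar function of many parameters

w = jnp.arange(3.0)
# One call yields the value and ALL partials, via a single backward pass;
# analytically, d/dw sin(w)**2 = sin(2w).
val, g = jax.value_and_grad(f)(w)
```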

> How's AD different from simply calling the function with a slightly perturbed input and observing how much the output changes?

AD gives a more precise result, since it calculates the value of the derivative at exactly the point you want, without any perturbed input. AD is also faster: computing the derivative's value usually requires about the same number of elementary operations as evaluating the function itself, whereas finite differences need at least one extra function evaluation per input dimension.
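An illustrative comparison of the two error magnitudes (assuming `jax` is installed; the function and step size are arbitrary):

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.exp(x)        # f'(x) = exp(x), known exactly
x = 1.0

ad = jax.grad(f)(x)             # automatic differentiation: no perturbation
h = 1e-3
fd = (f(x + h) - f(x)) / h      # forward finite difference with step h

ad_err = abs(ad - jnp.exp(x))   # essentially zero
fd_err = abs(fd - jnp.exp(x))   # O(h) truncation error, plus rounding
```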

What does it do if the function is not differentiable? Most functions commonly seen in real code aren't.

How does it compare to CuPy (in terms of Numpy compatibility)?
