The inferencing phase of a neural network attempts to minimize error or loss as defined by the user. This is done by iteratively applying gradient descent to the error function. Thus, the error function must have a known derivative, which can be difficult if the function has loops and conditionals.
Autograd is a software package that produces derivatives of a function automatically. It builds a computation graph from the user-defined code, from which it can determine the derivative.
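For instance, a minimal sketch using Autograd's grad function (the tanh definition here is just for illustration):

    import autograd.numpy as np  # Autograd's thin wrapper around NumPy
    from autograd import grad

    def tanh(x):
        # an ordinary Python function, written with wrapped NumPy ops
        return (1.0 - np.exp(-2 * x)) / (1.0 + np.exp(-2 * x))

    d_tanh = grad(tanh)  # grad returns a new function computing dtanh/dx
    print(d_tanh(1.0))   # ≈ 1 - tanh(1)^2 ≈ 0.4200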
XLA is a JIT compiler from TensorFlow that compiles common array operations. The JAX project from this GitHub page brings those JIT optimizations to Autograd's automatic differentiation. That will speed up the error function when inferring a neural network. Neat!
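As a rough sketch of what that composition looks like in JAX (the loss function here is made up for illustration):

    import jax.numpy as jnp
    from jax import grad, jit

    def loss(w, x, y):
        pred = jnp.tanh(jnp.dot(x, w))
        return jnp.mean((pred - y) ** 2)

    # grad builds the derivative function; jit compiles it with XLA
    grad_loss = jit(grad(loss))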
(I would be grateful for any corrections to my above explanation as I am not an expert in ML.)
Inferencing means “to predict” (I’m not sure when this terminology became popular; a few years ago most of us were just using the word predict)
Once trained, a model no longer requires derivatives. It’s more or less a function evaluation, which can be done on plain CPUs.
Also, shouldn't the verb be inferring?
Yes, I believe the verb is "inferring".
I've been using JAX for a while now. A paper I'm an author on (https://arxiv.org/abs/1806.09597, a follow-up to https://news.ycombinator.com/item?id=18633215) resulted in an algorithm that required taking second-derivatives on a per-example basis. This is extremely difficult in TensorFlow, but with JAX it was a 2-liner. Even better, it's _super_ fast, thanks to XLA's compile-to-GPU and JAX's auto-batching mechanics.
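For a sense of what that 2-liner might look like (a generic sketch, not the paper's actual code; the per-example loss is hypothetical):

    import jax.numpy as jnp
    from jax import hessian, jit, vmap

    def per_example_loss(params, x):
        # hypothetical loss for a single example; stands in for the real objective
        return jnp.sum(jnp.tanh(params * x) ** 2)

    # Hessian w.r.t. params, auto-batched over the leading (examples) axis of x
    per_example_hessians = jit(vmap(hessian(per_example_loss), in_axes=(None, 0)))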
I highly recommend JAX to power users. It's nowhere near as feature-complete from a neural network sense as, say, PyTorch, but it is very good at what it does, and its core developers are second to none in responsiveness.
About this time last year I had an optimization problem involving linear combinations of Voronoi cell centroids and I basically used random search (I can't remember why differential evolution didn't work).
I don't know much about, say, Keras, but I think I understand how to implement backprop-like behaviors with Autograd. Maybe even sandwich concept-specific models between fully connected linear+sigmoid layers to give them some oomph.
1) Take a closer look at tf.hessians()'s implementation. Notice the tf.while_loop(), calculating tf.gradients() for each coordinate of the first gradient? That doesn't happen in parallel! Plus, you need to know about TensorArrays!
2) Per-example gradients mean an outer tf.while_loop() over each example. Another linear slowdown!
The difficulty I refer to is in wrangling conditionals in TF, and the trickiness of obtaining partial derivatives (rather than total derivatives) which JAX makes trivial.
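To illustrate the partial-derivative point (a minimal sketch; f here is an arbitrary made-up function):

    import jax.numpy as jnp
    from jax import grad

    def f(a, b):
        return jnp.sum(a * b ** 2)

    # argnums selects which argument to differentiate with respect to,
    # giving the partial derivative w.r.t. b while holding a fixed
    df_db = grad(f, argnums=1)
    print(df_db(2.0, 3.0))  # 2*a*b = 12.0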
Seems a little limited in terms of supported operations for autograd compared to Chainer (https://chainer.org) or Flux (http://fluxml.ai).
Really cool to see though! XLA/TPU support is awesome.
I have some academic software (https://github.com/popgenmethods/momi2) that uses autograd. I was planning to port it to PyTorch since it's better supported/maintained, but now I'll have to consider JAX. Though I'm a little worried about the maturity of the project; it seems like the numpy/scipy coverage is not all the way there yet. Then again, it would be fun to contribute back to JAX. I did contribute a couple PRs to autograd back in the day, so I think I could jump right into it...
Definitely will try to keep in mind, when starting my next project, that a tool for fast differentiation of arbitrary functions exists out in the world.
Thanks for posting!
SymPy can differentiate functions, but they have to be set up as symbolic expressions. How can JAX differentiate native Python functions?
(Or do they mean numerical differentiation, like a finite difference estimation?)
In order to differentiate native functions you have to accept a fairly loose definition of "derivative" that basically assumes you can ignore discontinuities in functions. For example, in the function below we can say that the derivative is piecewise constant, with a value of 0 for all x < 0 and a value of 1 for all x >= 0. AD extends this idea to "each output produced by a function follows a single path, so the corresponding derivative follows the same path".
    def f(x):
        if x >= 0:
            return x
        return 0.0
AD cannot magically determine the derivative of `lambda x: exp(x)` or similar; it needs a lookup table of elementary functions, for example: https://github.com/HIPS/autograd/blob/304552b1ba42e42bce97f0... However, AD does support differentiating through program flow control, including function calls, loops, conditionals, etc., subject to the caveats above, which is much more difficult to do analytically.
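To give a flavor of what an entry in that lookup table looks like, here is a sketch using Autograd's extension API (`square` is a made-up primitive; the real table covers exp, log, sin, and so on):

    from autograd import grad
    from autograd.extend import primitive, defvjp

    # A made-up primitive: Autograd treats it as a black box...
    @primitive
    def square(x):
        return x * x

    # ...until we register its derivative rule by hand.
    # The VJP scales the incoming gradient g by d(square)/dx = 2x.
    defvjp(square, lambda ans, x: lambda g: g * 2 * x)

    print(grad(square)(3.0))  # 6.0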
The idea has now been generalized to all linear operators in functional programming https://www.youtube.com/watch?v=MmkNSsGAZhw&feature=youtu.be and the associated paper https://dl.acm.org/citation.cfm?doid=3243631.3236765
Looking at your GitHub link, multiplication (the product rule) doesn't seem to be handled there (only `def_linear(anp.multiply)`).
Would an implementation be something really straightforward, like:
    d/dx [f(x)*g(x)] = f(x)*g'(x) + f'(x)*g(x)
(Actually, the Wikipedia example for forward accumulation (https://wikipedia.org/wiki/Automatic_differentiation#Forward...) does include a product differentiated in this way, so I guess I got it right.)
def_linear is what handles the product rule. Other product operations (like anp.cross, anp.inner, etc.) are implemented the same way. It's called "linear" because products are multilinear functions, i.e. they are linear in each individual argument, and you can get the derivative with respect to each argument by simple substitution: (x+Δx)y − xy = Δx·y. Together with the chain rule for multi-parameter functions, the classic product rule falls out for free: for F(x,y) = xy, the per-argument contributions Δx·y and x·Δy sum to d(xy) = dx·y + x·dy.
Often, the interface for doing this is to just extend an existing numerical or array type and overload all the arithmetic operations to keep track of the gradients. Then ordinary code for numerical computations will "just work". That's basically what they've done here, except with a sophisticated compiler.
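A toy illustration of that overloading approach (a hypothetical Dual class sketching forward-mode AD, not JAX's actual internals):

    class Dual:
        """Carries a value and its derivative through ordinary arithmetic."""
        def __init__(self, val, dot):
            self.val, self.dot = val, dot

        def __mul__(self, other):
            # the product rule emerges from tracking both operands:
            # d(xy) = dx*y + x*dy
            return Dual(self.val * other.val,
                        self.dot * other.val + self.val * other.dot)

    x = Dual(3.0, 1.0)  # seed dx/dx = 1
    y = Dual(2.0, 0.0)  # treated as a constant w.r.t. x
    print((x * y).dot)  # 2.0, i.e. d(xy)/dx = y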
AD gives a more precise result: it calculates the value of the derivative at exactly the point you want, without any perturbed input. AD is also faster: computing the derivative usually requires about the same number of elementary operations as computing the function itself, whereas finite differences need extra function evaluations, at least one per input dimension.
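A quick comparison of the two approaches in JAX (the test function is arbitrary):

    import jax.numpy as jnp
    from jax import grad

    def f(x):
        return jnp.sin(x) * x ** 2

    x0, eps = 1.5, 1e-5
    fd = (f(x0 + eps) - f(x0 - eps)) / (2 * eps)  # central finite difference
    ad = grad(f)(x0)                              # AD: exact up to float error
    print(fd, ad)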