
Using JAX, numpy, and optimization techniques to improve separable image filters - bartwr
https://bartwronski.com/2020/03/15/using-jax-numpy-and-optimization-techniques-to-improve-separable-image-filters/
======
jphoward
Is it fair to say JAX is what Google made when they looked at PyTorch/Autograd
and thought, "oh damn, that's what we should have done"?

If so, is this the beginning of the end of TensorFlow? I know TensorFlow is
still on top for production, but it is certainly rapidly losing its following
in the research field, and PyTorch is now starting to focus on deployment, as
they know this is their weakness.

~~~
kragen
It sounds like JAX stores your whole calculation in memory, so it will
necessarily use more memory for automatic differentiation of heavily
iterative calculations, while other implementations of reverse-mode automatic
differentiation can instead restart your calculation from checkpoints to avoid
storing the whole thing. This could be an advantage of several orders of
magnitude for some calculations: using twice the CPU or GPU time in exchange
for one thousandth or one ten-thousandth of the memory.

~~~
mattjjatgoogle
Rematerialization in autodiff is super interesting! XLA does rematerialization
optimizations, so you get those automatically under jax.jit. There's also the
jax.checkpoint decorator
([https://github.com/google/jax/pull/1749](https://github.com/google/jax/pull/1749)),
which lets you control reverse-mode checkpointing yourself; you can use it
recursively to implement sophisticated checkpointing strategies (see Example 5
in that PR, which is the classic strategy for getting memory cost to scale
like log(N) for iteration count N, at the cost of log(N) times as much
computational work). It'd be interesting to experiment with heuristics for
deploying those strategies automatically (e.g. given a program in JAX's jaxpr
IR), but one of JAX's core philosophies is to keep things explicit and give
users control through composable APIs. Automatic heuristics can be built on
top.
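
Roughly, the recursive pattern looks like this (a minimal sketch of the idea,
not the code from the PR; step, the chain length, and the binary split are
made up for illustration):

    import jax
    import jax.numpy as jnp

    def step(x):
        # one step of some heavily iterative calculation
        return jnp.sin(x) + 0.1 * x

    def chain(n):
        # Binary recursive checkpointing: only O(log n) residuals are
        # live during the backward pass, at the cost of O(log n)-fold
        # recomputation.
        if n == 1:
            return step
        first, second = chain(n // 2), chain(n - n // 2)
        return lambda x: second(jax.checkpoint(first)(x))

    f = chain(256)
    print(jax.grad(f)(1.0))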

Another goal is to make JAX a great system for playing with things like this!

~~~
kragen
Thank you for the correction! I should have checked out the software before
posting my incorrect surmises from the blog post. It sounds awesome!

~~~
mattjjatgoogle
No worries! I didn't mean it as a correction so much as just a discussion; I'm
sure it's true that other autodiff systems have very sophisticated automatic
remat (like
[https://openreview.net/forum?id=BkYYXJ9i-](https://openreview.net/forum?id=BkYYXJ9i-)).
I'm hoping as users push JAX on new applications, especially in simulation and
scientific computing, we'll learn a lot and be able to improve!

There's also "cross-country optimization"
([https://www-sop.inria.fr/tropics/slides/EdfCea05.pdf](https://www-sop.inria.fr/tropics/slides/EdfCea05.pdf))
for mixing some forward-mode into reverse-mode to improve memory efficiency.
Analogously to jax.checkpoint, we've only experimented with exposing that
manually (in jax.jarrett, named because of
[https://arxiv.org/abs/1810.08297](https://arxiv.org/abs/1810.08297)), and
even then only for a special case. There's a lot to learn about, experiment
with, and build!
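
As a small taste of mixing the two modes with standard JAX APIs (ordinary
jacfwd/jacrev composition, not the jarrett transform itself): forward-over-
reverse is the usual way to build a Hessian, since reverse mode suits the
many-inputs-to-one-output gradient and forward mode handles the outer pass:

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sum(jnp.tanh(x) ** 2)

    # Reverse mode (jacrev) computes the gradient; running forward mode
    # (jacfwd) over it avoids a second full reverse sweep.
    hessian = jax.jacfwd(jax.jacrev(f))
    print(hessian(jnp.arange(3.0)))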

------
kragen
This is a very interesting article, with much better illustrations and deeper
investigation than the note on separable image filters I wrote in Dercuano.
I also didn't know about JAX before, and it's very valuable to know about it
now. But
the article does have some errors.

The article says:

> _Optimization of arbitrary functions is generally a NP-hard problem (there
> are no solutions other than exploring every possible value, which is
> impossible in the case of continuous functions)_

It is true that optimization of arbitrary functions, or even many interesting
classes of functions, is NP-hard. However, the definition given of NP-hard is
incorrect, and in fact, on modern hardware, existing SMT solvers such as Z3
can solve substantial instances of many interesting NP-hard optimization
problems, precisely because they do _not_ explore every possible value.
Moreover, it is in general possible (but again NP-hard) to use interval
arithmetic to rigorously optimize functions on continuous domains (which seems
to be what is meant), as long as they are not too _dis_continuous; the answer
you get is only an approximation of the true optimum, but you can calculate it
to any desired precision.
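
For instance, here's a toy 0/1 knapsack, an NP-hard problem in general,
handled by Z3's actual Optimize API (pip install z3-solver); the numbers are
invented, but the point is that the solver prunes the search space rather
than enumerating all 2^5 assignments:

    from z3 import Int, Optimize, Sum, sat

    weights, values, capacity = [2, 3, 4, 5, 9], [3, 4, 5, 8, 10], 10
    opt = Optimize()
    xs = [Int("x%d" % i) for i in range(len(weights))]
    for x in xs:
        opt.add(0 <= x, x <= 1)  # 0/1 decision variables
    opt.add(Sum([w * x for w, x in zip(weights, xs)]) <= capacity)
    opt.maximize(Sum([v * x for v, x in zip(values, xs)]))
    assert opt.check() == sat
    print(opt.model())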

One particularly interesting class of optimization problems, because they are
_not_ NP-hard, is that of continuous _linear_ optimization problems, which can
be solved in guaranteed polynomial time using interior-point methods, or
usually in polynomial time using the "simplex method". Contrary to what you
might infer from the article's quote, going from continuous to discrete makes
the problem NP-hard again. There is also a note in Dercuano surveying the
landscape of existing software and methods for solving linear optimization
problems; there's a lot of very powerful stuff out there.
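
To make the contrast concrete, here's a tiny continuous LP (my own made-up
instance) solved with SciPy's linprog, which wraps the HiGHS solver:

    from scipy.optimize import linprog

    # maximize x0 + 2*x1 subject to x0 + x1 <= 4, x0 + 3*x1 <= 6, x >= 0;
    # linprog minimizes, so we negate the objective.
    res = linprog(c=[-1.0, -2.0],
                  A_ub=[[1.0, 1.0], [1.0, 3.0]], b_ub=[4.0, 6.0],
                  bounds=[(0, None), (0, None)], method="highs")
    print(res.x, -res.fun)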

It turns out that you can efficiently solve an enormous range of practical
optimization problems by introducing a small number of discrete variables into
a linear optimization problem, thus gaining most of the performance benefit of
using a linear optimizer. I don't know if there's a way to get a reasonable
perceptual result in a case like this with a linear optimizer, though.
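
For example, the same toy LP with one variable forced to be an integer becomes
a small mixed-integer program, which scipy.optimize.milp (SciPy >= 1.9) hands
to HiGHS's branch-and-bound:

    import numpy as np
    from scipy.optimize import milp, LinearConstraint, Bounds

    res = milp(c=np.array([-1.0, -2.0]),
               constraints=LinearConstraint([[1, 1], [1, 3]], ub=[4, 6]),
               integrality=np.array([1, 0]),  # x0 integer, x1 continuous
               bounds=Bounds(lb=0, ub=np.inf))
    print(res.x, -res.fun)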

> _This is where various auto-differentiation libraries can help us. Given
> some function, we can compute its gradient / derivative with regards to some
> variables completely automatically! This can be achieved either
> symbolically, or in some cases even numerically if closed-form gradient
> would be impossible to compute._

Automatic differentiation is a specific approach to differentiation, an
_alternative_ both to symbolic differentiation and to the older
finite-difference kind of numerical differentiation. What JAX does is
automatic differentiation.
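
The distinction is easy to see in JAX itself: jax.grad mechanically applies
the chain rule to the program's primitive operations, giving a derivative that
is exact up to floating-point rounding, with no symbolic expression
manipulation and no finite-difference step size (toy example of my own):

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) ** 2

    # automatic differentiation: exact chain rule, no step size
    print(jax.grad(f)(1.0))  # equals 2*sin(1)*cos(1)
    # old-style numerical differentiation: approximate, step-size dependent
    eps = 1e-6
    print((f(1.0 + eps) - f(1.0 - eps)) / (2 * eps))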

