Foundations for Efficient and Expressive Differentiable Programming [pdf] (nips.cc)
80 points by espeed 3 months ago | 14 comments

I saw this paper at NIPS and thought it was an awesome application of continuation-passing style. Typically, one of the advantages of CPS is that you only need a single stack frame (since the remainder of the program is encoded in the callback), but in this paper they allow the CPS'd program to use multiple stack frames and store the intermediate results for backprop in the call stack. It never would've occurreded to me to give up one of the advantages of CPS in order to store values on the call stack. Cool idea!

Another paper on the same topic, which got best paper at ICFP this fall, is http://conal.net/papers/essence-of-ad/

It shows how to connect forward-mode and reverse-mode automatic differentiation in a very elegant way.

Co-author here. Happy to answer any questions, as usual ...

We also have a more recent and slightly longer draft with additional explanations and GPU training results for ResNet 50 and DeepSpeech2:

"Demystifying Differentiable Programming: Shift/Reset the Penultimate Backpropagator" https://www.cs.purdue.edu/homes/rompf/papers/wang-preprint20...

The Lantern framework is available here: https://github.com/feiwang3311/Lantern

Very cool.

Like all really good ideas, this one seems "obvious" in hindsight. I mean that as a compliment: it would never have occurred to me that transforming code into continuation-passing-style code would allow for automatic differentiation through all dynamic control-flow structures by leveraging the function-call stack, thus eliminating the need for some kind of "tape" data structure, e.g., as in PyTorch.

My question is about the ongoing work to provide a JIT compiler for Python code. Do you expect it will provide full support for the entire PyTorch and/or Tensorflow APIs?

Thanks! Yes, it wasn't obvious at all when we started looking at AD either.

Lantern supports a good deal of PyTorch (via Snek, our Python front-end similar to AutoGraph) and can also read ONNX. Full feature parity is not our main goal--so far, supported features have been driven mostly by what is required for certain interesting models.

How does this effort compare to Myia https://github.com/mila-udem/myia ?

This looks interesting, thanks.

What you call "delimited continuations" sounds a bit like how AD works in Julia's Flux package (and maybe elsewhere): During the forward calculation, a chain of functions is constructed, whose evaluation is the backward pass. This is done by overloading the original * to return both x * y and a closure Δ -> (Δ * y, x * Δ).
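For concreteness, here is a minimal sketch of that closure-returning style in Python (the names `mul` and `pullback` are illustrative stand-ins, not Flux's actual API):

```python
def mul(x, y):
    # Forward pass: return the primal result plus a pullback closure
    # that maps an output adjoint to the input adjoints.
    value = x * y
    def pullback(delta):
        # Chain rule for z = x * y: dz/dx = y, dz/dy = x.
        return (delta * y, x * delta)
    return value, pullback

# Differentiate f(x) = x * x at x = 3.
value, back = mul(3.0, 3.0)
dx, dy = back(1.0)
grad = dx + dy          # both arguments are the same variable x
print(value, grad)      # 9.0 6.0
```

The backward pass here is the chain of closures built up during the forward pass, evaluated outside-in.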

Does that sound right -- are these indeed similar, or have I misunderstood something? I have never read Scala, but if I squint I can see that your overloading of * is doing something similar.

It's different -- instead of returning a closure, we take a closure as additional parameter (check paper for details). This means that the call stack stays intact and all intermediate values can be stack-allocated.

Could you please elaborate on what exactly enables stack allocation of intermediate values?

Is it a direct result of the semantics of delimited continuations? (perhaps because closures are passed as parameters and don't capture variables)

Or is it enabled via staging or optimizations?

It's specific to CPS form: every basic statement is a full function call, which moves to the next statement by calling another function (i.e. the continuation). Normally this would just be a very quick way to get yourself a stack overflow, so typical CPS-form compilers optimise (or require) the tail-call case, where you can pop the current stack frame before moving on to the next call. But for AD the growing stack isn't a problem: you just re-use the values held in those frames during the backwards pass.
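The mechanism can be sketched in Python (illustrative names, not Lantern's actual Scala API): each op takes the rest of the program as a continuation `k`, and because the call to `k` is deliberately not a tail call, the op's frame -- holding the forward intermediates -- stays live until the backward step runs after `k` returns.

```python
class Num:
    def __init__(self, x):
        self.x = x      # primal value
        self.d = 0.0    # accumulated adjoint

def mul(a, b, k):
    y = Num(a.x * b.x)  # forward: intermediate lives in this stack frame
    k(y)                # run the remainder of the program (the continuation)
    a.d += y.d * b.x    # backward: adjoints computed after k returns,
    b.d += y.d * a.x    # reading a, b, y straight off the call stack

def grad(f, x0):
    x = Num(x0)
    f(x, lambda y: setattr(y, 'd', 1.0))  # seed the output adjoint
    return x.d

# f(x) = x * x, so f'(3) = 6
print(grad(lambda x, k: mul(x, x, k), 3.0))  # 6.0
```

No tape data structure appears anywhere; the nesting of stack frames plays that role.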

I doubt it's better than just allocating your own stack on the heap, especially when you have something like a differential equation solver with millions of statements, but it's still a neat trick.

Aha, that makes sense, thanks!

I agree that the performance benefits of such stack allocation (over heap allocation) aren't quite clear in practice.

I feel the bigger win of delimited cont./closure-based AD approaches is that they can model the control flow of reverse-mode AD without AD-specific code transformations. Delimited cont. is especially great at making things modular: each differentiable function performs primal computation, calls the callback with primal result, then performs adjoint computation.
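A minimal sketch of that modularity, assuming a toy Python encoding rather than the paper's Scala shift/reset operators: each op is self-contained (primal computation, callback with the primal result, adjoint computation), and ops compose by nesting continuations, with no global tape and no AD-specific code transformation.

```python
class Num:
    def __init__(self, x, d=0.0):
        self.x, self.d = x, d

def mul(a, b, k):
    y = Num(a.x * b.x)   # primal
    k(y)                 # callback with primal result
    a.d += y.d * b.x     # adjoint of multiplication
    b.d += y.d * a.x

def add(a, b, k):
    y = Num(a.x + b.x)   # primal
    k(y)                 # callback with primal result
    a.d += y.d           # adjoint of addition
    b.d += y.d

# f(x) = x*x + x, built purely by nesting continuations; f'(2) = 2*2 + 1 = 5
x = Num(2.0)
mul(x, x, lambda t: add(t, x, lambda y: setattr(y, 'd', 1.0)))
print(x.d)  # 5.0
```

Each op knows nothing about its callers or callees beyond the callback protocol, which is the modularity point above.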

OK, thanks, sounds worth trying to wrap my head around...

(Note BTW that the github page has a dead link to the paper.)

Thanks! Fixed now.

