This shows how to connect forward- and reverse-mode automatic differentiation in a very elegant way:
"Demystifying Differentiable Programming: Shift/Reset the Penultimate Backpropagator"
The Lantern framework is available here:
Like all really good ideas, this one seems "obvious" in hindsight. I mean that as a compliment: it would never have occurred to me that transforming code into continuation-passing style would allow automatic differentiation through all dynamic control-flow structures by leveraging the function-call stack, eliminating the need for some kind of "tape" data structure, as in, e.g., PyTorch.
My question is about the ongoing work to provide a JIT compiler for Python code. Do you expect it will provide full support for the entire PyTorch and/or TensorFlow APIs?
Lantern supports a good deal of PyTorch (via Snek, our Python front-end similar to AutoGraph) and can also read ONNX. Full feature parity is not our main goal; so far, supported features have been driven mostly by what is required for certain interesting models.
What you call "delimited continuations" sounds a bit like how AD works in Julia's Flux package (and maybe elsewhere): During the forward calculation, a chain of functions is constructed, whose evaluation is the backward pass. This is done by overloading the original * to return both x * y and a closure Δ -> (Δ * y, x * Δ).
Does that sound right, are these indeed similar, or have I misunderstood something? I have never read Scala, but if I squint I can see that your overloading of * is doing something similar.
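For concreteness, here is roughly what I mean, as a self-contained Scala sketch (my own illustration, not Flux's or Lantern's actual code): the op returns the primal result together with a pullback closure that maps the output adjoint to the input adjoints.

    // Each op returns (primal result, pullback closure).
    object PullbackSketch {
      def times(x: Double, y: Double): (Double, Double => (Double, Double)) =
        (x * y, delta => (delta * y, x * delta))

      def main(args: Array[String]): Unit = {
        val (z, pullback) = times(3.0, 4.0) // forward pass: z = 12.0
        val (dx, dy) = pullback(1.0)        // backward pass: dx = 4.0, dy = 3.0
        println(s"z=$z dx=$dx dy=$dy")
      }
    }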
Is it a direct result of the semantics of delimited continuations? (perhaps because closures are passed as parameters and don't capture variables)
Or is it enabled via staging or optimizations?
I doubt it's better than just allocating your own stack on the heap, especially when you have something like a differential equation solver with millions of statements, but it's still a neat trick.
I agree that the performance benefits of such stack allocation (over heap allocation) aren't quite clear in practice.
I feel the bigger win of delimited-continuation/closure-based AD approaches is that they can model the control flow of reverse-mode AD without AD-specific code transformations. Delimited continuations are especially good at making things modular: each differentiable function performs its primal computation, calls the callback with the primal result, then performs its adjoint computation.
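Concretely, the pattern looks roughly like this (a minimal sketch in plain CPS Scala, my own illustration; Lantern itself expresses the callback with shift/reset rather than an explicit parameter):

    // Reverse-mode AD where the backward pass rides on the call stack.
    class NumR(val x: Double, var d: Double = 0.0) {
      // Primal computation, then run the rest of the program (k),
      // then accumulate adjoints after k returns.
      def *(that: NumR)(k: NumR => Unit): Unit = {
        val y = new NumR(this.x * that.x) // primal
        k(y)                              // rest of the program
        this.d += that.x * y.d            // adjoint
        that.d += this.x * y.d
      }
    }

    object CpsAdSketch {
      // Gradient of f at x: seed the output adjoint in the final continuation.
      def grad(f: NumR => (NumR => Unit) => Unit)(x: Double): Double = {
        val in = new NumR(x)
        f(in)(out => out.d = 1.0)
        in.d
      }

      def main(args: Array[String]): Unit = {
        println(grad(a => k => a.*(a)(k))(3.0)) // d/dx (x * x) at 3.0 = 6.0
      }
    }

The adjoint updates run after the callback returns, so the backward pass unwinds with the call stack instead of replaying an explicit tape.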
(Note BTW that the github page has a dead link to the paper.)