Translation and accelerated solving of differential equations on GPU platforms (arxiv.org)
112 points by UncleOxidant 11 months ago | 32 comments



Uhh, they time the vmap of the jit in Jax, basically skipping a ton of optimizations, especially if there is any linear algebra in there. They also include the cost of building the vmapped function. Not a valid comparison.

https://github.com/utkarsh530/GPUODEBenchmarks/blob/ef807198...


This is very interesting. The claim does sound too good to be true. Am I understanding you correctly that you are saying including the vmap operation in the timing is wrong because it involves compilation time that could have been amortized over all the runs, and that the compilation time is considerable compared to the ode-solve itself?


There are two things going on. First, you're right that the vmap should have been done once outside the timing. But equally important, vmap(jit(...)) speed will generally be lower than jit(vmap(...)).

There are many reasons for this. First, the former will loop iterations on the GPU in a serial fashion. Second, the internal jit makes optimization opportunities opaque to Jax. For example, if there's a loop of matmuls inside main, that loop can be converted to a loop of einsums if you vmap first. It can also sometimes be fused into a bigger operation that doesn't shuttle control variables back and forth between CPU and GPU between time steps. Between the two, you both increase throughput and decrease latency.

I think in Jax, jit(vmap(jit(...))) will also reoptimize the same way as jit(vmap(...)) but I'm not 100% certain.
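A minimal sketch of the composition being recommended, assuming a made-up workload: `euler_solve` is a hypothetical fixed-step Euler integration of a linear ODE `dy/dt = A @ y` and does not come from the benchmark repo. The batched function is built once, with `jax.vmap` applied first and `jax.jit` outermost; the `vmap(jit(...))` version is kept only for comparison, and both should return the same numbers.

```python
import jax
import jax.numpy as jnp

def euler_solve(A, y0, dt=0.01, n_steps=100):
    """Hypothetical per-trajectory solver: fixed-step explicit Euler for
    dy/dt = A @ y. The loop body is a single matrix-vector product."""
    def body(_, y):
        return y + dt * (A @ y)
    return jax.lax.fori_loop(0, n_steps, body, y0)

A = jnp.array([[0.0, 1.0], [-1.0, 0.0]])               # harmonic oscillator
y0s = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])  # one row per trajectory

# vmap first, jit outermost: XLA compiles the whole batched loop, in which the
# per-trajectory matvecs become one batched matmul per step.
solve_batch = jax.jit(jax.vmap(euler_solve, in_axes=(None, 0)))

# The composition criticized above, kept only for comparison.
solve_batch_inner_jit = jax.vmap(jax.jit(euler_solve), in_axes=(None, 0))

ys = solve_batch(A, y0s)  # shape (3, 2): final state of each trajectory
```

Build `solve_batch` once and reuse it; constructing it inside the timed loop is exactly the cost the original comment objects to.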


On your last point, as long as you jit the topmost level, it doesn't matter whether or not you have inner jitted functions. The end result should be the same.

Source: https://github.com/google/jax/discussions/5199#discussioncom...


Confirmed in this Colab notebook that it doesn't make a tangible difference: https://colab.research.google.com/drive/1d7G-O5JX31lHbg7jTzz...


Same for PyTorch. I don't know PyTorch well enough, but I'm guessing they didn't JIT anything.

https://github.com/utkarsh530/GPUODEBenchmarks/blob/ef807198...


What they should do is build the vmap and jit that, then run timing on calling the resulting function.
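One way to do that, as a rough sketch: the harness below assumes the function passed in is already the compiled `jit(vmap(...))` callable, makes one untimed warm-up call so tracing and XLA compilation are excluded, and blocks on each result because JAX dispatches work asynchronously. `time_batched` and the `decay` workload are hypothetical names, not taken from the benchmark repo.

```python
import time
import jax
import jax.numpy as jnp

def time_batched(fn, *args, repeats=5):
    """Mean wall-clock seconds per call of fn(*args), excluding compilation."""
    # Warm-up call: triggers tracing and XLA compilation, not timed.
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously; block so the timer covers the work.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / repeats

def decay(k, dt=1e-3, n_steps=1000):
    """Hypothetical workload: explicit Euler for dy/dt = -k*y, y(0) = 1."""
    y0 = jnp.ones_like(k)
    return jax.lax.fori_loop(0, n_steps, lambda _, y: y - dt * k * y, y0)

solve = jax.jit(jax.vmap(decay))       # built once, outside the timed region
ks = jnp.linspace(0.1, 2.0, 10_000)    # one decay rate per trajectory
print(f"{time_batched(solve, ks):.6f} s per call")
```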


So instead of jitting main they should do something like

  @jax.jit
  @jax.vmap
  def main(...

?


This Colab notebook shows effectively no difference from doing this: https://colab.research.google.com/drive/1d7G-O5JX31lHbg7jTzz....

The average for diffrax on this Colab machine goes from 20.5 to something like 20.3 seconds. You can see DiffEqGPU.jl, run from Python via diffeqpy, at around 2.3 seconds. This is a very rough benchmark of course, since using DiffEqGPU has a fairly high (documented) overhead and the free-tier T4 GPU is not great, but it shows the ballpark: an order of magnitude or so. You can also see that compile times are pretty negligible even at this scale (and the paper benchmarks are a few orders of magnitude larger than this, so at that point they're really negligible).

That shouldn't be surprising, though, since we're talking about operations involving hundreds of thousands or millions of ODEs. At this scale, micro-optimizations tend to have a much more minor effect. And the paper describes in detail that we developed two versions of the methods in Julia: one that was an array-based approach like Jax and PyTorch vmaps (EnsembleGPUArray), and another that was a kernel-generating approach like MPGOS (EnsembleGPUKernel). Jax, PyTorch, and EnsembleGPUArray all performed similarly, while MPGOS performed similarly to EnsembleGPUKernel. To us, this was a pretty strong indicator that the performance difference comes from the fact that the way EnsembleGPUKernel performs the parallelism is very different from the approach an ML library takes. And yes, there are small differences within each group, but those are like 2x-3x or so, while the paper benchmarks are plotted on a log scale because the difference between the two classes of design is much larger.


Yes, that's right



Seems like the benchmarking code is a small script and what you are suggesting might be a few lines of code. Might be worthwhile to take a stab and see if there is a difference.


Submitters: "Please use the original title, unless it is misleading or linkbait; don't editorialize."

If you want to say what you think is important about an article, that's fine, but do it by adding a comment to the thread. Then your view will be on a level playing field with everyone else's: https://hn.algolia.com/?dateRange=all&page=0&prefix=false&so...

(Submitted title was "Julia GPU-based ODE solver 20x-100x faster than those in Jax and PyTorch". We've changed that to a shortened version of the paper title, to fit HN's 80 char limit.)


That rule seems problematic. You cite the rule, and then the last line of your comment explains that you've violated the rule.

As another example, I recently read a Factorio blog about how they do map generation, and there were a lot of technical details any aspiring game developer would be interested in, even if they don't play Factorio. The title of the blog post was "Maps 2.0" which would be meaningless as a HN title. Something like "How Factorio's procedural map generation works" would make more sense for HN, but would require breaking the rule. What should be done in this case?


It's problematic if you expect the rules to work like code; they don't. It's less problematic once you understand that HN has always been a spirit-of-the-law place, not a letter-of-the-law place (https://hn.algolia.com/?dateRange=all&page=0&prefix=false&qu...).

From that perspective it's easy to see how the submitted title was breaking the rule, and how shortening a title so as to fit HN's 80 char limit is not breaking the rule, as long as one doesn't shorten it in a misleading or linkbait way.

(Re your Factorio blog question, I'd have to see the particular article to answer that.)


"Instead of relying on high level array intrinsics that machine learning libraries use, it uses a direct kernel generation approach to greatly reduce the overhead." Chris Rackauckas on LinkedIn earlier today.


CR is a hero. The work he does on ODE and related libs in Julia is one of the selling points of the language. He is the Alex Crichton of Julia.


What is direct kernel generation?


In this context I would imagine it's constructing source code for a kernel (the engine that implements a step in a neural network) that is closer to optimal. See https://cuda.juliagpu.org/stable/tutorials/performance/ for related work.


https://cuda.juliagpu.org/stable/tutorials/performance/ provides various tips that someone who has written a kernel can use to speed it up, like using 32-bit integers and minimising runtime exceptions. Perhaps I'm misunderstanding, but it's not part of direct kernel generation, whatever that is.



Yeah you'd be surprised what the performance gain is for hand written kernels.

There's probably a ton left on the table if you really want to go fast.


In this case it's not tricks like those done in BLAS kernels. However, there are some intricacies in the algorithms that are chosen, as noted in the paper. That said, most of the difference here comes simply from the high-level design of the parallelism, not from low-level bit-hacking optimizations. We leave those for another day.


The difference is really the level at which you are calling functions on the GPU. Say you have a function `f(x,y,z) = x .+ y .* sin.(z)`. If you're using CUDA (simplifying here: the paper does this for Intel oneAPI, Metal, IPUs, and AMD GPUs simultaneously, but it's basically the same), then at some point you need to call a kernel function, a CUDA-compiled .ptx function that is then applied over all of the inputs. One way to parallelize this is to have a set of primitive functions, `x .+ y`, `x .* y`, `sin.(x)`, and then decompose the execution into those kernels: first call sin, then call multiply, then call plus. The other way is to build a specialized .ptx kernel for the function `f` on demand and call that. Machine learning libraries take the former approach, but we demonstrate here that the latter is much better in this scenario because the call overhead of each kernel is non-trivial and ends up slowing the whole process down. If there's a tl;dr for the paper it's this, plus scaling the approach to all GPU architectures from one codebase.

Now I'll simultaneously say that the choice machine learning libraries are making here is not stupid. You may look at this example and go "no duh call 1 kernel instead of 3", but you never want to over optimize. For the domain that ML libraries are designed for, these kernel calls are typically things like large matrix multiplications (that's the core of any deep neural network, with a few things around it). These kinds of operations are O(n^3) or O(n^2) on very large arrays. With that amount of compute to do on the memory, the overhead cost can go to nearly zero. Thus for the use case targeted by ML libraries, approaching the design of the GPU library as "just make enough optimized kernels" is a good design. For example, it was counted in 2021 that PyTorch had about 2,000 such kernels (https://dev-discuss.pytorch.org/t/where-do-the-2000-pytorch-...). Sit down, optimize the CUDA kernels, then make the high level code call the most appropriate one. That's a good design if the kernels are expensive enough, like in deep learning.
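The overhead argument can be put into rough numbers with a back-of-the-envelope cost model: time = launches x per-call overhead + FLOPs / throughput. Both constants below (5 µs per kernel call, 10 TFLOP/s) and the small ODE workload are assumptions picked for illustration, not measurements from the paper.

```python
# Back-of-the-envelope model: time = launches * overhead + flops / throughput.
# Both constants are assumptions for illustration, not measurements.
LAUNCH_OVERHEAD_S = 5e-6     # assumed cost of one kernel call, in seconds
THROUGHPUT_FLOPS = 10e12     # assumed sustained throughput, in FLOP/s

def overhead_fraction(flops, n_launches=1):
    """Fraction of total time spent on kernel-call overhead."""
    overhead = n_launches * LAUNCH_OVERHEAD_S
    compute = flops / THROUGHPUT_FLOPS
    return overhead / (overhead + compute)

# Deep-learning-style work: one dense 4096 x 4096 matmul, ~2 n^3 FLOPs.
n = 4096
matmul = overhead_fraction(2 * n**3)

# ODE-style work (hypothetical): one right-hand-side evaluation of a 3-state
# system for 1,000 trajectories at ~10 FLOPs per state, split across 5
# primitive kernel calls.
ode_rhs = overhead_fraction(10 * 3 * 1_000, n_launches=5)

print(f"matmul: {matmul:.2%} overhead, ODE RHS: {ode_rhs:.2%} overhead")
# prints: matmul: 0.04% overhead, ODE RHS: 99.99% overhead
```

The exact percentages depend entirely on the assumed constants; what carries over is the gap of several orders of magnitude between the two regimes.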

While Jax has a few other things going on, both the PyTorch and Jax vmap parallelism approaches are effectively high-level tools for shoving larger arrays more nicely into such existing kernels. For example, one optimization that vmap does is fuse matrix-vector multiplications into matrix multiplications, i.e. computing A*v1 and A*v2 separately becomes a single A*[v1 v2]. The purpose is still to use a small set of primitives and shove array operations that are as big as possible into them.
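A small JAX check of the matvec-to-matmul batching described here, with arbitrary example values: `jax.vmap` over a matrix-vector product gives the same result as one product with the vectors stacked as columns, and the traced program from `jax.make_jaxpr` contains a single `dot_general` rather than one per vector.

```python
import jax
import jax.numpy as jnp

A = jnp.arange(16.0).reshape(4, 4)
V = jnp.arange(8.0).reshape(2, 4)   # v1 and v2 as rows

def matvec(v):
    return A @ v                    # one matrix-vector product

batched = jax.vmap(matvec)(V)       # shape (2, 4): rows are A @ v1, A @ v2

# Same numbers as a single matmul with the vectors as columns: A @ [v1 v2].
assert jnp.allclose(batched, (A @ V.T).T)

# The traced program uses one dot_general for both vectors.
print(jax.make_jaxpr(jax.vmap(matvec))(V))
```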

However, that is not a good idea in all domains. In ODE solvers, you have lots of control flow and O(n) operations. This can make that "negligible" overhead very much not negligible, and thus one needs to design the parallelism very differently to avoid the performance issues one would hit with the "small kernel, array-based approach". The better approach in this domain (as demonstrated in the paper) is to build completely new kernels for the functions you're trying to compute, i.e. generate CUDA code and a .ptx kernel for f directly, compile that, and make one call. This has some downsides of course: the kernel effectively can't be reused for anything else, which means you need to be able to do this kernel generation automatically for it to be useful at the package level.

In other words, domain-specific languages optimize to their respective domain of choice, but that may be leaving performance on the table for use cases outside of their directly targeted audience.


This... doesn't seem to do anything special? Everyone already knew it was bad to "batch" ODEs by making them bigger, e.g. in "Neural Ordinary Differential Equations" (the paper that introduced neural ODEs):

> One can still batch together evaluations through the ODE solver by concatenating the states of each batch element together, creating a combined ODE with dimension D × K. In some cases, controlling error on all batch elements together might require evaluating the combined system K times more often than if each system was solved individually. However, in practice the number of evaluations did not increase substantially when using minibatches.

I don't understand why someone wrote a 30-page, obfuscated paper on just... parallelizing it the obvious way.
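The effect described in the quoted passage is easy to reproduce with a toy adaptive solver. This sketch assumes an embedded Heun/Euler pair with max-norm error control and three hypothetical decay rates for `dy/dt = -k*y`; it is not the method from either paper. Solved one at a time, each trajectory picks its own step sizes; solved as one concatenated system, every component has to take the steps the hardest component needs.

```python
import math

def adaptive_steps(f, y0, t_end, rtol=1e-4, atol=1e-6, h=1e-2):
    """Integrate y' = f(y), y a list, with an embedded Heun (order 2) /
    Euler (order 1) pair. The error norm is the max over components, so a
    combined system only accepts a step that is good enough for every
    component. Returns the number of accepted steps."""
    t, y, steps = 0.0, list(y0), 0
    while t < t_end:
        h = min(h, t_end - t)
        k1 = f(y)
        y_euler = [yi + h * a for yi, a in zip(y, k1)]
        k2 = f(y_euler)
        y_heun = [yi + h * (a + b) / 2 for yi, a, b in zip(y, k1, k2)]
        err = max(abs(yh - ye) / (atol + rtol * abs(yh))
                  for yh, ye in zip(y_heun, y_euler))
        if err <= 1.0:                      # accept the step
            t, y, steps = t + h, y_heun, steps + 1
        # Grow or shrink h (safety factor 0.9, change clamped to [0.2, 5]).
        h *= min(5.0, max(0.2, 0.9 / math.sqrt(max(err, 1e-12))))
    return steps

rates = [0.1, 1.0, 10.0]   # hypothetical decay rates, one per trajectory

# One solve per trajectory: each gets its own step-size sequence.
individual = [adaptive_steps(lambda y, k=k: [-k * y[0]], [1.0], 1.0)
              for k in rates]

# One concatenated system: all trajectories share one step-size sequence.
combined = adaptive_steps(lambda y: [-k * yi for k, yi in zip(rates, y)],
                          [1.0] * len(rates), 1.0)

print("individual steps:", individual, "combined steps:", combined)
print("component updates:", sum(individual), "vs", combined * len(rates))
```

Solving each trajectory on its own, as a per-trajectory kernel can, lets the slowly decaying components finish in far fewer steps; the concatenated system makes all of them pay for the fastest-decaying one.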


I mean, it at least must not be obvious to the poster that says "the claim does sound too good to be true". But yes anyone with enough of an HPC background can look at how vmap is doing its parallelization and instantly know that ML frameworks like Jax and PyTorch are most likely losing an order of magnitude or two of performance. And of course we are very explicit in the paper that this is not novel because we show that the kernels that we are generating match the performance of MPGOS, which is a CUDA library which has the same architecture.

But of course, all of this discussion leaves off half of the title of the paper, "on Multiple GPU Platforms". The point is not that we are able to generate kernels which are doing the fast thing that a dedicated CUDA library does (i.e. not the slow thing that ML libraries are doing), rather the point is that we are doing this in a way where CUDA is not special. We generate similarly optimized kernels for AMD GPUs, Intel GPUs, and Apple silicon (Metal) using this approach. Mose also showed this same codebase can generate kernels for GraphCore IPUs without modifications too (see https://www.youtube.com/watch?v=-fxB0kmcCVE). Showing matching performance with good kernel codes was step 1 but portability (with a full feature set) is step 2. I'd be interested to know if you have any examples of ODE solvers which achieve this level of performance portability because we were unable to find one in the literature or open source.


Link to GitHub repo from the abstract: https://github.com/SciML/DiffEqGPU.jl


OT: I was hoping for an HTML version in light of:

https://blog.arxiv.org/2023/12/21/accessibility-update-arxiv...

I guess this wasn't uploaded in LaTeX?

Ed: Oh, this may be a date of submission thing:

> as long as papers were submitted on or after December 1st, 2023 and HTML conversion is successful


Replace the x with a 5 in the abstract URL.


https://ar5iv.org/abs/2304.06835

Thank you! I guess that's what this passage refers to:

> If you are familiar with ar5iv, an arXivLabs collaboration, our HTML offering is essentially bringing this impactful project fully “in-house”. Our ultimate goal is to backfill arXiv’s entire corpus so that every paper will have an HTML version, but for now this feature is reserved for new papers.


Anyone remember analog computers? They were really good at solving differential equations


> Titan Black

Huh. Is that another price/perf knee?

    GPU          Price (USD)  VRAM (GB)  FP64 TFLOPS
    Titan Black       $99          6         1.882
    Titan V          $500         12         7.450
    4090            $2000         24         1.290

Looks like it! Man, I'm glad AMD is at least trying now; this has gotten sad.
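For the price/performance question, a quick calculation using only the prices and FP64 figures listed above (the commenter's numbers, not checked here):

```python
# Prices (USD) and FP64 TFLOPS exactly as listed in the table above.
cards = {
    "Titan Black": (99, 1.882),
    "Titan V": (500, 7.450),
    "4090": (2000, 1.290),
}
per_1000 = {name: 1000 * tflops / price for name, (price, tflops) in cards.items()}
for name, value in per_1000.items():
    print(f"{name:12s} {value:6.2f} FP64 TFLOPS per $1000")
```

That works out to roughly 19, 15, and 0.65 FP64 TFLOPS per $1000, so on these numbers the older cards come out ahead by more than an order of magnitude for double precision.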



