
There are two things going on. First, you're right that the vmap should have been done once, outside the timing loop. But equally important, vmap(jit(...)) will generally be slower than jit(vmap(...)).

There are several reasons for this. First, the former executes iterations serially on the GPU. Second, the inner jit makes optimization opportunities opaque to JAX. For example, if there's a loop of matmuls inside main, that loop can be converted to a loop of einsums if you vmap first. It can also sometimes be fused into a single larger operation that doesn't bounce control variables back and forth between CPU and GPU between time steps. Together these both increase throughput and decrease latency.

I think in JAX, jit(vmap(jit(...))) will also be reoptimized the same way as jit(vmap(...)), but I'm not 100% certain.
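A minimal sketch of the two composition orders (the function f and its matmul loop are illustrative stand-ins, not from the original benchmark). Both orderings compute identical results; only how XLA gets to optimize the batched computation differs:

```python
import jax
import jax.numpy as jnp

# Illustrative per-example function: a chain of matmuls, the kind of
# workload where composition order matters.
def f(x):
    w = jnp.eye(4)
    for _ in range(3):
        x = x @ w
    return x

xs = jnp.ones((8, 4, 4))  # batch of 8 examples

# jit(vmap(f)): XLA traces the whole batched computation at once,
# so the matmul loop can be batched/fused as a unit.
fast = jax.jit(jax.vmap(f))

# vmap(jit(f)): the inner jit boundary hides structure from the
# outer transformation's optimizer.
slow = jax.vmap(jax.jit(f))

# Numerically, the two compositions agree; only performance differs.
assert jnp.allclose(fast(xs), slow(xs))
```

To actually time the difference, call `.block_until_ready()` on the outputs so GPU dispatch isn't measured as instantaneous.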




On your last point, as long as you jit the topmost level, it doesn't matter whether or not you have inner jitted functions. The end result should be the same.

Source: https://github.com/google/jax/discussions/5199#discussioncom...
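A quick sketch of that point (f here is just a hypothetical example function): wrapping an already-jitted function in jit(vmap(...)) produces the same result as jitting the plain function, since the top-level jit retraces through inner jit boundaries.

```python
import jax
import jax.numpy as jnp

# Hypothetical elementwise function standing in for a real model step.
def f(x):
    return jnp.sin(x) * 2.0

xs = jnp.arange(6.0).reshape(2, 3)

# Only the outermost jit...
outer_only = jax.jit(jax.vmap(f))

# ...versus a redundant inner jit as well.
nested = jax.jit(jax.vmap(jax.jit(f)))

# The top-level jit governs compilation; the results match.
assert jnp.allclose(outer_only(xs), nested(xs))
```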


Confirmed in this Colab notebook that it doesn't make a tangible difference: https://colab.research.google.com/drive/1d7G-O5JX31lHbg7jTzz... .



