Show HN: How does JAX allocate memory on a TPU? An interactive C++ walkthrough (gist.github.com)
78 points by sillysaurusx 71 days ago | 28 comments

One thing that's important to note about memory management with XLA is that inside a compiled program there's no user-exposed "malloc"/"free". The memory usage and schedule of a given program/graph is statically optimized by the compiler (which is why JAX requires static shapes when jitting). When running in op-by-op/eager mode, the allocated buffers are tied to the lifetime of the Python array object and are freed when that array is garbage-collected.
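A minimal sketch of both halves of that, assuming a recent JAX release (the names `nonzero_static` and `nonzero_dynamic` are made up for illustration). Under `jax.jit`, an op whose output size depends on the data only compiles if you give it a fixed size. In eager mode, a `jax.Array`'s device buffer can also be released explicitly with `.delete()` instead of waiting for garbage collection.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0.0, 3.0, 0.0, 5.0])

@jax.jit
def nonzero_static(v):
    # size= fixes the output shape at trace time; unused slots get fill_value.
    return jnp.nonzero(v, size=3, fill_value=-1)[0]

@jax.jit
def nonzero_dynamic(v):
    # The output length depends on the data, so no static buffer size exists.
    return jnp.nonzero(v)[0]

print(nonzero_static(x))  # indices 1 and 3, then one -1 padding slot

raised = False
try:
    nonzero_dynamic(x)
except Exception as err:  # JAX raises a concretization error at trace time
    raised = True
    print(type(err).__name__)

# Eager mode: the buffer lives as long as the Python object,
# but it can be freed early.
y = jnp.ones((1024, 1024))
y.delete()
print(y.is_deleted())
```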

>The memory usage and schedule of a given program/graph is statically optimized by the compiler

Is it really though? The only thing I see is


which traces back to


which doesn't do anything smart, as far as I can tell.

That first call is from tf2xla, which JAX doesn't use. But I think those calls just set up the global buffers used by the launched program's inputs/temporaries. The entire point of XLA's static shape discipline is allowing for a highly optimized memory schedule inside compiled programs.
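This doesn't show where the scheduling routine lives inside XLA, but you can at least see the compiler's static buffer accounting from Python. A sketch assuming a recent JAX where `Compiled.memory_analysis()` exists; the toy `mlp` function is hypothetical, and some backends or versions may not implement the query, hence the fallback to `None`.

```python
import jax
import jax.numpy as jnp

def mlp(x, w1, w2):
    # h is an intermediate whose buffer XLA plans at compile time.
    h = jnp.tanh(x @ w1)
    return h @ w2

x = jnp.ones((128, 256), jnp.float32)
w1 = jnp.ones((256, 512), jnp.float32)
w2 = jnp.ones((512, 10), jnp.float32)

compiled = jax.jit(mlp).lower(x, w1, w2).compile()

try:
    stats = compiled.memory_analysis()
except Exception:  # some backends/versions do not implement this query
    stats = None

if stats is not None:
    print("arguments:", stats.argument_size_in_bytes)
    print("outputs:  ", stats.output_size_in_bytes)
    print("temps:    ", stats.temp_size_in_bytes)
```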

I believe you, but I'm wondering what that memory scheduling actually looks like (e.g. solving the full MIP?). Do you have a code pointer to that scheduling routine?

What has always kept me from trying JAX is the following statement, which is pretty prominent in the JAX GitHub README:

> This is a research project, not an official Google product. Expect bugs and sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!

Why doesn't Google push this harder at a time when TensorFlow is falling behind so drastically in mindshare and market share?!

Because TF is the official one. Pushing Jax would mean giving up on TF, and I don’t think Google is ready to do that yet.

The writing is on the wall inside Google: Pathways is JAX, not TF; DM uses JAX instead of TF most of the time; and most new researchers are not adopting TF. It's going to die a slow death like MapReduce.

Can't go far on HN these days before some hot take makes you roll your eyes. The number of Borg cells that run DM jobs is a rounding error relative to the number running TF jobs.

I used to work for Google on machine learning. It's clear where the trends were going. Also, DM has entire TPU pods allocated to it (like search and ads do).

At the risk of sounding like a HN stereotype, for those who don't know what a TPU is:

> TPUs are hardware accelerators specialized in deep learning tasks. They are supported in Tensorflow 2.1 both through the Keras high-level API, and at a lower level.

They're pretty rad. You can even get a few (dozen) for free: https://blog.gpt4.org/jaxtpu

Incidentally, that blog runs on a TPU. You can just SSH into them.

Stuff like that is why I think Jax will beat pytorch in the long run: https://blog.gpt4.org/mlmind

A couple years from now, Metabook might find themselves in real trouble. They won't be the React of ML anymore. All the researchers who want to get work done have already flocked to Jax, so I suspect it's a matter of time till the rest of the world notices.

Time will tell. For now, it's by far the most fun I've had in my entire career. I've been in ML for a while now, but my background was gamedev, finance, and security -- I encourage you to dip your toe into the weird ML scene, because it happens to be a blast nowadays.

> They're pretty rad. You can even get a few (dozen) for free: https://blog.gpt4.org/jaxtpu

> Incidentally, that blog runs on a TPU. You can just SSH into them.

As someone without a machine learning background, I assumed a TPU is something like a GPU. Aren't they used for ML as well? So I'm surprised you can run linux userspace on it?

Oh, I read [1] too late. It's a (GNU/?)Linux VM with the special hardware already made available.

[1] https://news.ycombinator.com/item?id=29129554

> So I'm surprised you can run linux userspace on it?

For what it's worth, I was equally shocked! Felt like a miracle.

That special hardware also turns out to be miraculously easy to use: https://github.com/tensorflow/tensorflow/blob/master/tensorf...

  // To compile: gcc -o libtpu_client libtpu_client.c -ldl
  // To run: sudo ./libtpu_client
I had to do a double-take, because compared to the uniquely hellacious experience of installing CUDA drivers on Ubuntu, this seemed to be... a single header file, and a single .c file.

Turns out, it is that easy: https://twitter.com/theshawwn/status/1400749405356052483

And not because there's some SDK preinstalled -- it's because if you unmask libtpu like a Scooby-Doo villain, you'll discover libtpu is LuaJIT in disguise. You can even do the equivalent of Lua's loadstring() function: https://github.com/tensorflow/tensorflow/blob/dd60c07888b6e7...

It's just called TpuDriver_CompileProgramFromText instead of loadstring, and you have to write in a quirky assembly language.

But you don't need to write in that quirky assembly language, because you can just `import jax`, jit a function with `jax.jit`, then dump it as HLO text. So you can just copy-paste it into your C file and run it :)
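A sketch of the dumping step, assuming a recent JAX with the ahead-of-time `lower()`/`compile()` API (the function `f` is just an example). `lowered.as_text()` gives the StableHLO that newer JAX emits before XLA runs; `compiled.as_text()` gives the optimized HLO module. Whether that text is accepted verbatim by `TpuDriver_CompileProgramFromText` is untested here.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0 + 1.0

x = jnp.ones((8,), jnp.float32)

lowered = jax.jit(f).lower(x)      # trace and lower, no XLA compilation yet
stablehlo_text = lowered.as_text() # pre-optimization IR (StableHLO/MLIR text)

compiled = lowered.compile()       # run XLA's optimizing compiler
hlo_text = compiled.as_text()      # optimized HLO module as text

print(hlo_text)
```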

> compared to the uniquely hellacious experience of installing CUDA drivers on Ubuntu, this seemed to be... a single header file, and a single .c file.

> Turns out, it is that easy: https://twitter.com/theshawwn/status/1400749405356052483

> And not because there's some SDK preinstalled

Maybe not an SDK, but AFAICT it is that easy because all the drivers and stuff are already installed.

If you start with the drivers already installed, this is exactly how GPGPU works too. You can write an almost-identical program using OpenCL (or CUDA) instead of libtpu; the only difference is that instead of this HLO, the program is in SPIR-V (or OpenCL C, PTX, CUDA C).

And you can start with all the drivers installed by using GCP or AMI images with all the GPU stuff preinstalled.

It is awesome that this is accessible for free.

> that blog runs on a TPU. You can just SSH into them.

By “runs on a TPU,” you mean it runs on a CPU in a VM somewhere that also has access to TPU hardware, right?

If this is the new TPU phraseology, that’s pretty confusing IMO.

You’re right, I spoke carelessly. The proper term is that the blog is running on a “TPU VM”, which is exactly what you describe: a box with /dev/accel0 that libtpu.so uses to communicate directly with the TPU hardware.

The difference is, every TPU VM has 96 CPU cores and 350GB of RAM. (Pods only get 48 cores per host, but they have 1 host per 8 TPU cores, and the smallest pod has 32 TPU cores, for a whopping 1.2TB of RAM and 192 CPU cores.)
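Spelling out the pod arithmetic. The constants are the figures quoted in the comment; the per-host RAM is only what the 1.2TB total implies, not a published spec.

```python
# Figures from the comment above; RAM per pod host is inferred, not stated.
TPU_CORES_PER_HOST = 8
CPU_CORES_PER_POD_HOST = 48
SMALLEST_POD_TPU_CORES = 32
POD_TOTAL_RAM_TB = 1.2

hosts = SMALLEST_POD_TPU_CORES // TPU_CORES_PER_HOST  # 32 / 8 = 4 hosts
cpu_cores = hosts * CPU_CORES_PER_POD_HOST            # 4 * 48 = 192 cores
ram_per_host_gb = POD_TOTAL_RAM_TB * 1000 / hosts     # implied GB per host

print(hosts, cpu_cores, round(ram_per_host_gb))
```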

Which is to say, I still think of them as “a TPU”, because nowhere in the world have I ever been able to access that amount of raw horsepower. Not on x86_64 Ubuntu that you can pip install things on, at least. Like a blog. :)

Wanna see a magic trick? Clone tensorflow onto a TPU VM and start building it. htop will light up like a Christmas tree (https://twitter.com/theshawwn/status/1400771262901854214) and it'll finish in about 25 minutes flat, if I remember correctly.

So yeah, Google is throwing around compute like Israel doling out vacations to Tel Aviv for distant relatives: it’s totally free. Partake!

(I’m really looking forward to seeing Israel someday. I never realized how beautiful the beaches are...)

96 core TPU VMs, free as in beer! It’s so exciting that I just can’t shut up about it.

TRC gives you access to 100 separate VMs the moment you sign up.

Having access to 100 VM-yachts totally rules. SSHing into 100 of them feels like commanding a WW2 carrier division, or something.

It’s quite literally too fun: I have to force myself not to spend all day unlocking their secrets. There’s so much new territory to explore; every day feels like a tremendous adventure. Our computing ancestors could only dream of exploiting as much hardware as our M1s take for granted, let alone one of these behemoths. Let alone one hundred!

I went back to look at my old notes. Christmas in 2019 was magical, because in January of 2020 I managed to fire up santa's 300 TPUs, while a colleague fired up santa's other 100 TPUs. Then we swarmed them together into a tornado of compute so powerful that even connecting to all 400 TPUs required special configuration settings ("too many open files" aka sockets): https://twitter.com/theshawwn/status/1221241517626445826
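The thread doesn't say which settings were changed. One common fix for "too many open files" is raising the per-process file-descriptor limit, since every socket uses a descriptor. A minimal sketch assuming a Linux host (the shell equivalent is `ulimit -n`); the 65536 fallback is an arbitrary choice:

```python
import resource

# Each open socket consumes a file descriptor; a soft limit of 1024 is common
# and easy to exhaust when holding connections to hundreds of hosts.
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
print("before:", soft, hard)

# An unprivileged process may raise its soft limit up to the hard limit.
if hard != resource.RLIM_INFINITY:
    new_soft = hard
else:
    new_soft = max(soft, 65536)
resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))
print("after:", resource.getrlimit(resource.RLIMIT_NOFILE))
```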

We were training models so fast that I bet even Goku in the hyperbolic time chamber would have a hard time training faster.

Thanks for the clarification.

It’s cool that Google has a ton of money and can give people free access to lots of big servers, but what you are excited about (lots of CPU cores, RAM, many VMs) seems to be mostly unrelated to the actual new TPU hardware, which is sorta disappointing.

PyTorch is working on catching up: I think they’ve already got some kind of “vmap”-style function transformations in beta. And I’m sure they’ll figure out good higher-order derivatives too. That’s like 90% of what people want out of Jax, so I think they’ll be able to compete.
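For reference, the JAX side of those two features as a small sketch (the function `f` is just an example): composing `jax.grad` gives higher-order derivatives, and `jax.vmap` maps a function written for a single scalar over a batch.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 3

df = jax.grad(f)             # 3x^2
d2f = jax.grad(jax.grad(f))  # 6x: a second derivative by composing grad

xs = jnp.array([1.0, 2.0, 3.0])
# vmap turns the scalar functions into batched ones.
first = jax.vmap(df)(xs)
second = jax.vmap(d2f)(xs)
print(first, second)
```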

The downside of Jax is it’s not easy to debug. PyTorch, for better or for worse, will actually run your Python code as you wrote it.

I've found jax's debugging to be better in some ways and worse in others. The fact that the function transformations are traced is great. It means you can step-debug through the tracing steps just as well as the actual eval steps, and you just see Tracer objects instead of jnp.ndarrays, or whatever. Outside of the transformations, it's just as easy to debug as NumPy, which is a blessing. That's one of the biggest selling points.
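A small sketch of what that looks like, assuming current JAX (`relu_sq` is a made-up example): under `jax.grad` alone, the argument is a tracer that still carries a concrete value, so ordinary Python `if` statements work. Wrap the same function in `jax.jit` and the branch on a traced value fails.

```python
import jax

seen_types = []

def relu_sq(x):
    # Record what kind of object the transformation hands us.
    seen_types.append(type(x).__name__)
    if x > 0:  # Python control flow on the value: fine under grad without jit
        return x ** 2
    return 0.0 * x

grad_val = jax.grad(relu_sq)(3.0)
print(grad_val, seen_types[-1])  # gradient of x**2 at 3.0, and a *Tracer type name

jit_failed = False
try:
    jax.jit(relu_sq)(3.0)  # under jit, x is abstract, so `x > 0` has no value
except Exception as err:   # JAX reports a tracer-to-bool conversion error
    jit_failed = True
    print(type(err).__name__)
```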

Debugging jitted and pmapped code, on the other hand, is a pain. Since you can always step outside of them to debug correctness, what really sucks is debugging performance issues. And boy does it suck. If anyone knows a good story for figuring out why my jitted thing is slow as hell on TPU, I'm all ears. The profiling section of the official docs is one of their weaker sections. (But big props to the overall documentation quality!)
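Not the profiling story being asked for, but two checks that usually come first. JAX dispatches work asynchronously, so timings need `block_until_ready()`, and the first call of a jitted function includes tracing and compilation, so it should be timed separately from steady state. A sketch with a hypothetical `step` function; for full traces, `jax.profiler.trace` can write profiles that TensorBoard reads.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((512, 512), jnp.float32)

t0 = time.perf_counter()
step(x).block_until_ready()           # first call: trace + compile + run
t1 = time.perf_counter()
result = step(x).block_until_ready()  # later calls reuse the cached executable
t2 = time.perf_counter()

first_call_s = t1 - t0
steady_s = t2 - t1
print(f"first call: {first_call_s:.3f}s, steady state: {steady_s:.3f}s")
```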

> The downside of Jax is it’s not easy to debug. PyTorch, for better or for worse, will actually run your Python code as you wrote it.

Hmm. Jax's ease of debugging was the very first thing that caught my attention: https://blog.gpt4.org/jaxtpu#:~:text=pdb.set_trace()

> I ran it on the TPU VM, saw the loss curve go down, and it was like an electric shock. "Wow! That actually... worked? Huh. that's weird. Things never work on the first try. I'm impressed."

> Then I plopped `import pdb; pdb.set_trace()` in the middle of the `loss` function and ran it again. It dropped me into the Python debugger.

> There was a tensor named `X_bt`. I typed `X_bt`. The debugger printed the value of `X_bt`.

> I was able to print out all the values of every variable, just like you'd expect Python to be able to do.

> There was a tensor named `Y_bt`. I typed `X_bt + Y_bt`. I was now staring at exactly what I expected: the sum of those two tensors.

> I could write `x + y`, or create new variables, or anything else I wanted.

> Now I was real impressed.

> If it sounds weird that I'm so easily impressed, it's because, you gotta understand: until now, TPUs were a complete pain in the ass to use. I kept my feelings to myself, because I understood that the Cloud TPU team were working hard to improve TPUs, and the TFRC support team was wonderful, and I had so many TPUs to play with. But holy moly, if you were expecting any of the above examples to just work on the first try when using Tensorflow v1 on TPUs, you were in for a rude awakening. And if you thought "Well, Tensorflow v2 is supposedly a lot better, right? Surely I'll be able to do basic things without worrying...."

> ... no. Not even close. Not until Jax + TPU VMs.

In the subsequent year, it's been nothing but joy.

If the problem is that you want to see tensor values in a JIT'ed function, use a host callback. You can run actual Python wherever you want: https://jax.readthedocs.io/en/latest/jax.experimental.host_c...

> This module introduces the host callback functions call(), id_tap(), and id_print(), that send their arguments from the device to the host and invoke user-defined Python functions on the host, optionally returning results back to the device computation.
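A sketch of the same idea in newer JAX, which steers users away from `jax.experimental.host_callback` (where `id_tap`/`id_print` live) toward `jax.debug.callback`, `jax.pure_callback`, and `jax.experimental.io_callback`. This swaps in the first two; the helper names `log_on_host` and `host_sin` are hypothetical:

```python
import numpy as np
import jax
import jax.numpy as jnp

logged = []

def log_on_host(v):
    # Ordinary Python on the host; v arrives as an array copied off the device.
    logged.append(np.asarray(v))

def host_sin(v):
    # Arbitrary host-side NumPy code whose result is sent back to the device.
    return np.sin(v).astype(v.dtype)

@jax.jit
def f(x):
    y = x * 2.0
    jax.debug.callback(log_on_host, y)  # observe a value mid-computation
    z = jax.pure_callback(host_sin, jax.ShapeDtypeStruct(y.shape, y.dtype), y)
    return z + 1.0

out = f(jnp.arange(3.0))
jax.effects_barrier()  # wait for any pending callbacks to finish
print(logged[0], out)
```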

The nice part is, there's no "magic" under the hood. If you get a chance, I highly recommend reading through Autodidax: https://jax.readthedocs.io/en/latest/autodidax.html

Autodidax is a pure-python implementation of jax. (Literally in one file, on that page.) It walks you through how every aspect of jax works.
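You can look at the same kind of program representation that Autodidax builds by hand using the real `jax.make_jaxpr`, which traces a Python function and returns the recorded jaxpr. A small example (`f` is arbitrary):

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * y + 1.0

# make_jaxpr traces f with abstract values and returns the recorded program.
jaxpr = jax.make_jaxpr(f)(1.0, 2.0)
print(jaxpr)
```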

Delightfully, I found a secret branch where autodidax also implements host callbacks: https://github.com/google/jax/blob/effect-types/docs/autodid...

If you scroll to the very bottom of that file, you'll see an example of compiling your own XLA JIT'ed code which subsequently calls back into Python. TPUs do precisely the same thing.

Point being:

> PyTorch, for better or for worse, will actually run your Python code as you wrote it.

... is also true of jax, to within a rounding error less than "I personally don't mind writing id_print(x) instead of print(x)." :)

Thanks, this is going to be very helpful for me. I guess it’s kind of like that old piece of advice: if you want some free Linux tech support, just post “Linux can’t do this but Windows can” :)

Plenty of free services like Google Colab exist for Tensorflow/Pytorch and offer GPUs/TPUs. I pay just $10 a month to train my models at my leisure on it, and the free version is great if you're doing shorter training runs.

Jax uses XLA (an IR that targets multiple platforms, including the Google TPU) as its backend. Pytorch also has an XLA backend.

A terrible one. It’s literally one of the worst experiences you can have. I say that without an ounce of bias.

When tpu SSH first came out, I immediately went to my pytorch buddy (who became famous in the meantime by pioneering the CLIP AI art you’ve probably been seeing: https://mobile.twitter.com/rivershavewings) and said “Rivers, you have to try TPUs! They’re wonderful! Now you can run pytorch on them directly!”

She was skeptical, because she’d had nothing but endless problems with the official pytorch TPU notebooks (which use a TPU in “remote” mode, aka it fires up an RPC server you attach to).

“No no, trust me — I bet their rpc driver was just horrible. It can’t be true that the Pytorch XLA backend is terrible on TPUs. How could that be? Facebook makes awesome software, and they’ve worked on this for so long. Some intern probably wrote the rpc system or something. Anyway, you can SSH into mine! See, no RPC server! What’s your pubkey?”

A few hours later, she reported that it froze in exactly the same place.

I could hardly believe it. I was so upset that I dug really deeply to try to figure out what the heck the problem was.

It was recompiling itself. Every. Inference.

Every compile took 15 minutes.

Their XLA backend is so far behind Jax that it’s not even a contest.

Worse, I’ve since realized why they can’t fix the problem. JAX gives you precise control over which functions get JIT’ed (compiled into XLA), and when that happens.

Pytorch doesn’t. There’s no @torch.jit decorator for your functions.

That means they need to infer which parts to jit (the inference model), and which parts not to (the inference loop).

The moment that magic fails, congrats; it’s now recompiling every iteration. (15 min per iteration is impressive, since usually it’s measured in “iterations per second”…)
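A sketch of what that explicit control looks like on the JAX side, with a hypothetical `step` function: only the step is compiled, the loop stays ordinary Python, and a trace-time counter shows the step is traced once because the argument shapes and dtypes never change. This says nothing about torch_xla internals.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def step(params, x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces step
    return params * 0.9 + 0.1 * jnp.mean(x)

params = jnp.ones((4,), jnp.float32)
for i in range(100):  # the loop itself is plain, uncompiled Python
    x = jnp.full((8,), float(i), dtype=jnp.float32)
    params = step(params, x)  # same shapes/dtypes each time: one compile

print(trace_count)
```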

JAX even has a TPU compiler cache now, which persists across Python runs. It’s like MacOS in 2021 vs Windows in 1995.
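Turning the persistent cache on is a config setting. A sketch assuming a recent JAX release (older versions used `jax.experimental.compilation_cache` instead); the temporary directory is only for illustration, and in practice you'd use a stable path so later runs can reuse it:

```python
import tempfile
import jax
import jax.numpy as jnp

# Illustrative location only; a real setup would use a persistent path.
cache_dir = tempfile.mkdtemp(prefix="jax_cache_")
jax.config.update("jax_compilation_cache_dir", cache_dir)
# By default only slow compilations are persisted; 0 caches everything.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

@jax.jit
def f(x):
    return jnp.cos(x) ** 2 + jnp.sin(x) ** 2

# A later Python process pointed at the same cache_dir can reuse this compile.
result = f(jnp.linspace(0.0, 1.0, 5))
print(result)
```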

This is deeply frustrating to me, because there’s no reason for it — I don’t know if Facebook is intentionally kneecapping pytorch to keep people off of TPUs, or if it’s just run of the mill incompetence. But (like Windows 95) they’re going to discover they’re not the only game in town forever.

The pytorch programming model is just really hard to adapt to an XLA-like compiler. Imperative python code doesn't translate to an ML graph compiler particularly well; Jax's API is functional, so it's easier to translate to the XLA API. By contrast, torch/xla uses "lazy tensors" that record the computation graph and compile when needed. The trouble is, if the compute graph changes from run to run, you end up recompiling a lot.

I guess in Jax you'd only apply `jax.jit` to the parts where the compute graph is static? I'd be curious to see examples of how this works in practice. FWIW, there's an offshoot of PyTorch that is aiming to provide this sort of API (see https://github.com/pytorch/functorch and look at eager_compilation.py).

(Disclaimer: I worked on this until quite recently.)
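On the question above about applying `jax.jit` only where the graph is static: one common pattern is marking the graph-shaping arguments as static with `static_argnums`. Each new static value or new input shape triggers one retrace and compile, and everything else hits the cache. A sketch with a hypothetical `forward` function:

```python
from functools import partial
import jax
import jax.numpy as jnp

compiles = []

@partial(jax.jit, static_argnums=1)
def forward(x, num_layers):
    compiles.append((x.shape, num_layers))  # runs once per new trace
    for _ in range(num_layers):  # Python loop, unrolled at trace time
        x = jnp.tanh(x)
    return x

forward(jnp.ones(4), 2)
forward(jnp.zeros(4), 2)  # same shape and static value: cache hit
forward(jnp.ones(4), 3)   # new static value: retrace and recompile
forward(jnp.ones(5), 3)   # new shape: retrace and recompile

print(len(compiles))
```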

Underscoring your point, GCP has a tutorial page for training Pytorch models on TPU: https://cloud.google.com/tpu/docs/tutorials/pytorch-pod

High-five fellow CLion user! I've found it exceedingly good at handling C++, C and Rust projects on GTK / Gnome.
