Is it really though? The only thing I see is
which traces back to
which doesn't do anything smart that I can tell.
> This is a research project, not an official Google product. Expect bugs and sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!
Why doesn't Google push more on this in times where Tensorflow is falling behind in mind/market share pretty drastically?!
> TPUs are hardware accelerators specialized in deep learning tasks. They are supported in Tensorflow 2.1 both through the Keras high-level API, and at a lower level.
Incidentally, that blog runs on a TPU. You can just SSH into them.
Stuff like that is why I think Jax will beat pytorch in the long run: https://blog.gpt4.org/mlmind
A couple years from now, Metabook might find themselves in real trouble. They won't be the React of ML anymore. All the researchers who want to get work done have already flocked to Jax, so I suspect it's a matter of time till the rest of the world notices.
Time will tell. For now, it's by far the most fun I've had in my entire career. I've been in ML for a while now, but my background was gamedev, finance, and security -- I encourage you to dip your toe into the weird ML scene, because it happens to be a blast nowadays.
As someone without a machine learning background, I assumed a TPU is something like a GPU. Aren't they used for ML as well? So I'm surprised you can run a Linux userspace on them?
For what it's worth, I was equally shocked! Felt like a miracle.
That special hardware also turns out to be miraculously-easy to use: https://github.com/tensorflow/tensorflow/blob/master/tensorf...
// To compile: gcc -o libtpu_client libtpu_client.c -ldl
// To run: sudo ./libtpu_client
Turns out, it is that easy: https://twitter.com/theshawwn/status/1400749405356052483
And not because there's some SDK preinstalled -- it's because if you unmask libtpu like a Scooby-Doo villain, you'll discover libtpu is LuaJIT in disguise. You can even do the equivalent of Lua's loadstring() function: https://github.com/tensorflow/tensorflow/blob/dd60c07888b6e7...
It's just called TpuDriver_CompileProgramFromText instead of loadstring, and you have to write in a quirky assembly language.
But you don't need to write in that quirky assembly language, because you can just `jax.jit` a function, then dump it as HLO text. Copy-paste that into your C file and run it :)
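A sketch of that workflow, assuming a recent jax where `jit(...).lower(...)` is available (on older versions, `jax.xla_computation` served the same purpose):

```python
# Sketch: get the HLO text of a jitted function -- the same kind of
# program text you could hand to TpuDriver_CompileProgramFromText.
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.dot(x, y) + 1.0

x = jnp.ones((4, 4))
lowered = jax.jit(f).lower(x, x)                    # trace + lower, no execution
hlo_text = lowered.compiler_ir("hlo").as_hlo_text() # contains "HloModule ..."
print(hlo_text)
```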
> Turns out, it is that easy: https://twitter.com/theshawwn/status/1400749405356052483
> And not because there's some SDK preinstalled
Maybe not an SDK, but AFAICT it is that easy because all the drivers and stuff are already installed.
If you start with the drivers already installed, this is exactly how GPGPU works too. You can write an almost-identical program using OpenCL (or CUDA) instead of libtpu; the only difference is that instead of HLO, the program is written in SPIR-V (or OpenCL C, PTX, CUDA C).
And you can start with all the drivers installed by using GCP or AMI images with all the GPU stuffs preinstalled.
It is awesome that this is accessible for free.
By “runs on a TPU,” you mean it runs on a CPU in a VM somewhere that also has access to TPU hardware, right?
If this is the new TPU phraseology, that’s pretty confusing IMO.
The difference is, every TPU VM has 96 CPU cores and 350GB of RAM. (Pods only get 48 cores per host, but they have 1 host per 8 TPU cores, and the smallest pod has 32 TPU cores, for a whopping 1.2TB of RAM and 192 CPU cores.)
Which is to say, I still think of them as “a TPU”, because nowhere in the world have I ever been able to access that amount of raw horsepower. Not on x86_64 Ubuntu that you can pip install things on, at least. Like a blog. :)
Wanna see a magic trick? Clone tensorflow onto a TPU VM and start building it. htop will light up like a Christmas tree (https://twitter.com/theshawwn/status/1400771262901854214) and it'll finish in about 25 minutes flat, if I remember correctly.
So yeah, Google is throwing around compute like Israel doling out vacations to Tel Aviv for distant relatives: it’s totally free. Partake!
(I’m really looking forward to seeing Israel someday. I never realized how beautiful the beaches are...)
96 core TPU VMs, free as in beer! It’s so exciting that I just can’t shut up about it.
TRC gives you access to 100 separate VMs the moment you sign up.
Having access to 100 VM-yachts totally rules. SSHing into 100 of them feels like commanding a WW2 carrier division, or something.
It’s quite literally too fun: I have to force myself not to spend all day unlocking their secrets. There’s so much new territory to explore — every day feels like a tremendous adventure. Our computing ancestors could only dream of exploiting as much hardware as our M1s take for granted, let alone one of these behemoths. Let alone one hundred!
I went back to look at my old notes. Christmas in 2019 was magical, because in January of 2020 I managed to fire up Santa's 300 TPUs, while a colleague fired up Santa's other 100 TPUs. Then we swarmed them together into a tornado of compute so powerful that even connecting to all 400 TPUs required special configuration settings ("too many open files", aka sockets): https://twitter.com/theshawwn/status/1221241517626445826
We were training models so fast that I bet even Goku in the hyperbolic time chamber would have a hard time training faster.
It’s cool that Google has a ton of money and can give people free access to lots of big servers, but what you’re excited about (lots of CPU cores, RAM, many VMs) seems mostly unrelated to the actual new TPU hardware, which is sorta disappointing.
The downside of Jax is it’s not easy to debug. PyTorch, for better or for worse, will actually run your Python code as you wrote it.
Debugging jitted and pmapped code, on the other hand, is a pain. Since you can always step out of them to debug correctness, what really sucks is debugging performance issues. And boy does it suck. If anyone knows a good story for figuring out why my jitted thing is slow as hell on TPU, I'm all ears. The profiling section of the official docs is one of the weaker ones. (But big props to the overall documentation quality!)
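For what it's worth, capturing a trace is at least easy to do; whether the trace answers the "why is this slow" question is another matter. A minimal sketch (the `/tmp/jax-trace` path is arbitrary; viewing the result needs TensorBoard with the profile plugin):

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.tanh(x @ x)

x = jnp.ones((512, 512))
step(x).block_until_ready()                 # warm up: keep compilation out of the trace

with jax.profiler.trace("/tmp/jax-trace"):  # writes a trace viewable in TensorBoard
    step(x).block_until_ready()             # block so the device work lands in the trace
```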
Hmm. Jax's ease of debugging was the very first thing that caught my attention: https://blog.gpt4.org/jaxtpu#:~:text=pdb.set_trace()
> I ran it on the TPU VM, saw the loss curve go down, and it was like an electric shock. "Wow! That actually... worked? Huh. that's weird. Things never work on the first try. I'm impressed."
> Then I plopped `import pdb; pdb.set_trace()` in the middle of the `loss` function and ran it again. It dropped me into the Python debugger.
> There was a tensor named `X_bt`. I typed `X_bt`. The debugger printed the value of `X_bt`.
> I was able to print out all the values of every variable, just like you'd expect Python to be able to do.
> There was a tensor named `Y_bt`. I typed `X_bt + Y_bt`. I was now staring at exactly what I expected: the sum of those two tensors.
> I could write `x + y`, or create new variables, or anything else I wanted.
> Now I was real impressed.
> If it sounds weird that I'm so easily impressed, it's because, you gotta understand: until now, TPUs were a complete pain in the ass to use. I kept my feelings to myself, because I understood that the Cloud TPU team was working hard to improve TPUs, the TFRC support team was wonderful, and I had so many TPUs to play with. But holy moly, if you were expecting any of the above examples to just work on the first try when using Tensorflow v1 on TPUs, you were in for a rude awakening. And if you thought "Well, Tensorflow v2 is supposedly a lot better, right? Surely I'll be able to do basic things without worrying...."
> ... no. Not even close. Not until Jax + TPU VMs.
In the subsequent year, it's been nothing but joy.
If the problem is that you want to see tensor values in a JIT'ed function, use a host callback. You can run actual Python wherever you want: https://jax.readthedocs.io/en/latest/jax.experimental.host_c...
> This module introduces the host callback functions call(), id_tap(), and id_print(), that send their arguments from the device to the host and invoke user-defined Python functions on the host, optionally returning results back to the device computation.
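As a sketch of what the callback round trip looks like in practice — note that in newer JAX releases `host_callback` has been superseded by `jax.debug.print` and `jax.debug.callback`, which do the same device-to-host hop:

```python
import jax
import jax.numpy as jnp

@jax.jit
def loss(x):
    y = x * 2.0
    jax.debug.print("y = {}", y)  # runs a Python-side print from inside the JIT'ed function
    return jnp.sum(y)

loss(jnp.arange(3.0))  # prints y's host-side value alongside returning the sum
```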
The nice part is, there's no "magic" under the hood. If you get a chance, I highly recommend reading through Autodidax: https://jax.readthedocs.io/en/latest/autodidax.html
Autodidax is a pure-python implementation of jax. (Literally in one file, on that page.) It walks you through how every aspect of jax works.
Delightfully, I found a secret branch where autodidax also implements host callbacks: https://github.com/google/jax/blob/effect-types/docs/autodid...
If you scroll to the very bottom of that file, you'll see an example of compiling your own XLA JIT'ed code which subsequently calls back into Python. TPUs do precisely the same thing.
> PyTorch, for better or for worse, will actually run your Python code as you wrote it.
... is also true of jax, to within a rounding error less than "I personally don't mind writing id_print(x) instead of print(x)." :)
When tpu SSH first came out, I immediately went to my pytorch buddy (who became famous in the meantime by pioneering the CLIP AI art you’ve probably been seeing: https://mobile.twitter.com/rivershavewings) and said “Rivers, you have to try TPUs! They’re wonderful! Now you can run pytorch on them directly!”
She was skeptical, because she’d had nothing but endless problems with the official pytorch TPU notebooks (which use a TPU in “remote” mode, aka it fires up an RPC server you attach to).
“No no, trust me — I bet their RPC driver was just horrible. It can’t be true that the PyTorch XLA backend is terrible on TPUs. How could that be? Facebook makes awesome software, and they’ve worked on this for so long. Some intern probably wrote the RPC system or something. Anyway, you can SSH into mine! See, no RPC server! What’s your pubkey?”
A few hours later, she reported that it froze in exactly the same place.
I could hardly believe it. I was so upset that I dug really deeply to try to figure out what the heck the problem was.
It was recompiling itself. Every. Inference.
Every compile took 15 minutes.
Their XLA backend is so far behind Jax that it’s not even a contest.
Worse, I’ve since realized why they can’t fix the problem. JAX gives you precise control over which functions get JIT’ed (compiled into XLA), and when that happens.
PyTorch’s XLA backend doesn’t. There’s no `@jax.jit`-style decorator telling it which of your functions to compile.
That means they need to infer which parts to jit (the inference model), and which parts not to (the inference loop).
The moment that magic fails, congrats; it’s now recompiling every iteration. (15 min per iteration is impressive, since usually it’s measured in “iterations per second”…)
JAX even has a TPU compiler cache now, which persists across Python runs. It’s like macOS in 2021 vs Windows in 1995.
This is deeply frustrating to me, because there’s no reason for it — I don’t know if Facebook is intentionally kneecapping PyTorch to keep people off of TPUs, or if it’s just run-of-the-mill incompetence. But (like Windows 95) they’re going to discover they’re not the only game in town forever.
I guess in Jax you'd just only apply `jax.jit` to the parts where the compute graph is static? I'd be curious to see examples of how this works in practice. Fwiw, there's an offshoot of pytorch that is aiming to provide this sort of API (see https://github.com/pytorch/functorch and look at eager_compilation.py).
(Disclaimer: I worked on this until quite recently.)
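Roughly, yes — jit the static-shaped step and leave the surrounding loop as plain Python. A minimal sketch of the usual pattern (the names and the toy regression are illustrative, not from any real codebase):

```python
import jax
import jax.numpy as jnp

@jax.jit                           # compiled once per input shape/dtype
def train_step(params, x, y):
    grads = jax.grad(lambda p: jnp.mean((p * x - y) ** 2))(params)
    return params - 0.1 * grads

params = jnp.array(0.0)
x = jnp.arange(4.0)
y = 3.0 * x                        # fit y = 3x
for _ in range(100):               # plain Python loop: dynamic, never compiled
    params = train_step(params, x, y)
# params converges toward 3.0
```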