Show HN: Tune LLaMa3.1 on Google Cloud TPUs (github.com/felafax)
189 points by felarof 22 days ago | 52 comments
Hey HN, we wanted to share our repo where we fine-tuned Llama 3.1 on Google TPUs. We’re building AI infra to fine-tune and serve LLMs on non-NVIDIA GPUs (TPUs, Trainium, AMD GPUs).

The problem: Right now, 90% of LLM workloads run on NVIDIA GPUs, but there are equally powerful and more cost-effective alternatives out there. For example, training and serving Llama 3.1 on Google TPUs is about 30% cheaper than NVIDIA GPUs.

But developer tooling for non-NVIDIA chipsets is lacking. We felt this pain ourselves. We initially tried using PyTorch XLA to train Llama 3.1 on TPUs, but it was rough: XLA integration with PyTorch is clunky, libraries we rely on are missing (bitsandbytes didn't work), and we kept hitting cryptic HuggingFace errors.

We then took a different route and translated Llama 3.1 from PyTorch to JAX. Now it's running smoothly on TPUs! We still have challenges ahead (there's no good LoRA library in JAX, for example), but this feels like the right path forward.
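For a flavor of what that translation looks like (a minimal sketch of my own, not the actual Felafax code), here's a Llama-style RMSNorm written once as a PyTorch module and once as a pure JAX function that XLA can compile for TPUs:

    import torch
    import torch.nn as nn
    import jax
    import jax.numpy as jnp

    # PyTorch version (roughly the shape of a Llama-style RMSNorm module).
    class RMSNorm(nn.Module):
        def __init__(self, dim: int, eps: float = 1e-5):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x):
            var = x.pow(2).mean(-1, keepdim=True)
            return x * torch.rsqrt(var + self.eps) * self.weight

    # JAX version: parameters become explicit inputs, the function stays pure,
    # and jax.jit hands the whole graph to XLA, which maps cleanly onto TPUs.
    @jax.jit
    def rms_norm(params, x, eps=1e-5):
        var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        return x * jax.lax.rsqrt(var + eps) * params["weight"]

    params = {"weight": jnp.ones(4096)}
    out = rms_norm(params, jnp.ones((2, 16, 4096)))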

Here's a demo (https://dub.sh/felafax-demo) of our managed solution.

Would love your thoughts on our repo and vision as we keep chugging along!




I'm pretty sure anyone finetuning Llama on a regular basis now is using https://github.com/unslothai/unsloth so comparisons should be against that. The open source version is ~2x faster than default implementations. NVIDIA only, although the kernels are in Triton so they might be portable.


I remember seeing them on HN when they first started! I never understood what the catch is: how did they get such a big speedup and lower memory usage?


There are previous comments about this; apparently the founder did a lot of math, re-deriving things from scratch :)

https://news.ycombinator.com/item?id=39672070

https://unsloth.ai/blog/gemma-bugs


Nice work in gemma-bugs -- compared to plenty of research work that's a km deep in real math, this tech note is just a few Python tweaks. But finding those and doing it? Apparently it's useful, and they did it. Easy to read (almost child-like) writeup. Thanks for pointing to this.


The main author used to work at Nvidia. There's a free plan, and you can pay to get multi-GPU support.


Indeed, a LoRA finetune of Llama 3.1 8B works on a single 24GB GPU and takes from a few hours to a few days depending on the dataset size.


Very cool! Unlocking TPU training is a big win.

FWIW, if this helps prioritize: personally I'd find LoRA training for Llama 3.1 most useful (which it sounds like currently isn't well-supported with Felafax?) since with something like vLLM you can serve large numbers of LoRAs that share the same underlying GPU resources (assuming they're based on the same base model), vs full finetunes where each model will need to deploy on its own set of GPUs. In general I would guess that full finetunes are going to be less cost effective for most enterprise use cases: finetuning — whether full-finetuning or PEFT — generally improves only task-specific performance, so assuming you've got more than one task you want to use a model for in your business, it'll pretty quickly become dramatically cheaper to do the tasks with LoRAs rather than full finetunes unless you're saturating the boxes for each specific task. So, I'm hoping you guys build support for LoRA training with JAX in addition to finetuning!
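To make the serving argument concrete, here's roughly what multi-LoRA serving looks like with vLLM (a sketch from memory of vLLM's LoRA support; the model name, adapter names, and paths are just placeholders):

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # One copy of the base model sits in GPU memory...
    llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enable_lora=True, max_loras=4)
    params = SamplingParams(max_tokens=128)

    # ...shared by many task-specific adapters (paths are hypothetical).
    support_lora = LoRARequest("support", 1, "/adapters/support-tickets")
    sql_lora = LoRARequest("sql", 2, "/adapters/text-to-sql")

    print(llm.generate(["Summarize this ticket: ..."], params, lora_request=support_lora))
    print(llm.generate(["Write a SQL query that ..."], params, lora_request=sql_lora))

Each full finetune, by contrast, needs its own dedicated copy of those 8B weights on its own GPUs.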


Thanks for the detailed feedback! Yes, supporting LoRA fine-tuning is one of the things we are already working on.

BTW, we already have LoRA supported with the Llama 3 PyTorch XLA model. Check that out in the meantime.
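For anyone wondering what LoRA looks like in JAX without a dedicated library, here's a minimal Flax sketch (my own illustration, not the Felafax implementation; LoRALinear and the parameter names are made up):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class LoRALinear(nn.Module):
        features: int
        rank: int = 8
        alpha: float = 16.0

        @nn.compact
        def __call__(self, x):
            d_in = x.shape[-1]
            # Frozen base weight (in practice loaded from the pretrained checkpoint).
            w = self.param("w", nn.initializers.lecun_normal(), (d_in, self.features))
            # Trainable low-rank adapters: A (d_in x r) and B (r x features).
            a = self.param("lora_a", nn.initializers.normal(0.02), (d_in, self.rank))
            b = self.param("lora_b", nn.initializers.zeros, (self.rank, self.features))
            return x @ w + (self.alpha / self.rank) * (x @ a @ b)

    layer = LoRALinear(features=4096)
    variables = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 4096)))
    # During training, only lora_a and lora_b would receive gradient updates.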


I am actually not surprised that JAX converts better to XLA. Also, deep respect for anybody in this space, as there is a lot of complexity to deal with at the framework and compiler level.


Thank you! Yeah, there are a few complexities and very little documentation around JAX, plus a lot of missing libraries.


I'm totally new to AI. If I take, for example, Llama 3.1 (the small 8B size), what's the rough budget to fine-tune it on, say, 1GB of extra text data using any cloud GPU service? (If compute time is not a problem, I can wait.)


Let's assume that the average token size in your 1GB file is 4 characters (which is about what the OpenAI tokenizer generally gets; I assume the Llama tokenizer is similar). 4 chars is 4 bytes, assuming you're using UTF-8 and your characters are in the Latin range, so your training data is about 268MM tokens.

Let's assume you're doing a single-epoch LoRA training run. A single H100 should be enough to train Llama 3.1 8B, and it should crank through 268MM tokens in a couple of hours, IMO. Since you're not doing multi-GPU training, a PCIe H100 should be fine (you don't need the slightly pricier SXM H100s), and the PCIe versions go for about $2.50/hr on Runpod.

So, about $5 for a custom model that's probably the best in the world at whatever your task is! (Even if it might be a little dumber at other tasks.) Insanely cheap when you think about it.
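Same back-of-envelope math in code, with the assumptions spelled out (the throughput and price figures are the rough guesses from above, not measurements):

    # Every input here is an assumption.
    data_bytes = 1 * 1024**3              # 1 GiB of UTF-8, mostly Latin-range text
    bytes_per_token = 4                   # ~4 chars/token, OpenAI-tokenizer-ish average
    tokens = data_bytes / bytes_per_token # ~268MM tokens

    hours = 2                             # rough single-epoch LoRA pass on one H100
    price_per_hour = 2.50                 # PCIe H100 on-demand, Runpod-ish pricing
    print(f"{tokens / 1e6:.0f}MM tokens, ~${hours * price_per_hour:.2f} total")
    # -> 268MM tokens, ~$5.00 total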

TPUs won't beat H100s on price for on-demand personal use cases, but for reserved capacity (i.e. businesses) they're slightly cheaper.


I'm still new to LoRA/fine tunes, but: I can't just dump in 1GB of data, correct? I need to structure it in Question/Answer or the like?

So it would seem the cost really becomes converting/curating the data into a usable format first.


You can dump in 1GB of data (Unsloth supports "raw text training"), but whether you'd get good results or a useless model is a different issue. I doubt you'd get a good result unless you combine it with question/answer training as well, assuming that feature is even useful for your scenario.
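For concreteness, the two kinds of training data being discussed look roughly like this (field names are illustrative; different trainers expect different schemas):

    # Raw text / continued pretraining: the model just learns to predict the next token.
    raw_sample = "Section 4.2: The flux capacitor requires 1.21 gigawatts of power to ..."

    # Instruction (question/answer) format: teaches the model how to respond.
    qa_sample = {
        "instruction": "How much power does the flux capacitor require?",
        "output": "It requires 1.21 gigawatts.",
    }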


Really incredible :O I was imagining numbers with two extra zeros


Do you have any apples-to-apples speed and cost comparisons across Nvidia vs. non-NVIDIA chips (as you mentioned: TPUs, Trainium, AMD GPUs)?


Google published this benchmark a year or so ago comparing TPU vs NVIDIA (https://github.com/GoogleCloudPlatform/vertex-ai-samples/blo...)

The conclusion is at the bottom, but the TL;DR was that TPUs were 33% cheaper on a performance-per-dollar basis and that JAX scales very well compared to PyTorch.

If you are curious, there was a thorough comparison done by Cohere, published in this paper: https://arxiv.org/pdf/2309.07181 -- TPU+JAX turned out to be more performant and more fault tolerant (fewer weird errors).


> For example, training and serving Llama 3.1 on Google TPUs is about 30% cheaper than NVIDIA GPUs

When you say this, you should specify which Nvidia GPU you mean (I assume H100 SXM) and the price you are assuming for that GPU.

One can't simply compare based on the on demand price on GCP, because the Nvidia GPUs there are extremely overpriced.


Runpod charges $3.49/hr for an H100 SXM, which is fairly cheap as far as on-demand H100s go. A v5p TPU is $4.20/hr, but has 95GB RAM instead of 80GB on the H100 — so you'll need fewer TPUs to get the same amount of RAM.

Runpod is ever-so-slightly cheaper than Google TPUs on-demand on a per-GB basis: about 4.3 cents an hour per GB for Runpod vs 4.4 cents an hour per GB for a TPU. But let's look at how they compare with reserved pricing. Runpod is $2.79/hr with a 3-month commitment (the longest commitment period they offer), whereas Google offers v5p TPUs for $2.94/hr for a 1-year commitment (the shortest period they offer; and to be honest, you probably don't want to make 3-year commitments in this space, since there are large perf gains in successive generations).

If you're willing to do reserved capacity, Google is cheaper than Runpod per GB of RAM you need to run training or inference: about 3.5 cents per GB per hour for Runpod vs about 3.09 cents for Google. Additionally, Google presumably has a lot more TPU capacity than Runpod has GPU capacity, and doing multi-node training is a pain with GPUs and less so with TPUs.

Another cheap option to benchmark against is Lambda Labs. Now, Lambda is pretty slow to boot, and considerably more annoying to work with (e.g. they only offer preconfigured VMs, so you'll need to do some kind of management on top of them). They offer H100s for $2.99/hr "on-demand" (although in my experience, prepare to wait 20+ minutes for the machines to boot); if cold boot times don't matter to you, they're even better than Runpod if you need large machines (they only offer 8xH100 nodes, though: nothing smaller). For a 1-year commit, they'll drop prices to $2.49/hr... Which is still more expensive on a per-GB basis than TPUs — 3.11 cents per GB per hour vs 3.09 cents per GB per hour — and again I'd trust Google's TPU capacity more than Lambda's H100 capacity.

It's not dramatically cheaper than the cheapest GPU options available, but it is cheaper if you're working with reserved capacity — and probably more reliably available in large quantities.
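For reference, here are the per-GB numbers above spelled out (using the on-demand and reserved rates quoted in this comment, which will drift over time):

    # ($/hr, accelerator memory in GB) as quoted above.
    options = {
        "Runpod H100 SXM (on-demand)": (3.49, 80),
        "GCP TPU v5p (on-demand)":     (4.20, 95),
        "Runpod H100 SXM (3-month)":   (2.79, 80),
        "GCP TPU v5p (1-year)":        (2.94, 95),
        "Lambda H100 (1-year)":        (2.49, 80),
    }
    for name, (price, mem_gb) in options.items():
        print(f"{name:30s} {price / mem_gb * 100:.2f} cents/GB/hr")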


Thank you for the detailed analysis. We need to spend some time thinking and coming up with a price comparison like this. We’ll use this as inspiration!


VRAM per GPU isn't such an interesting metric. If it were, everyone would be fine-tuning on 80GB A100s :)

What matters is steps per $ and to some degree also speed (I'm happy to pay premium sometimes to get the fine tuning results faster).


True, but a TPU v5p is supposedly much closer to an H100 than to an A100 (the A100 and TPU v4 were fairly similar), and you need the RAM as a baseline just to fit the model. I haven't seen super thorough benchmarking between the two, but Google claims similar numbers. So $/RAM/hr is all I can really look at without benchmarking, sadly.


GCP is one of the cheapest places you can get them at scale.


Wouldn't really say it's the cheapest option... there are other providers like Lambda Labs or Ori.co where you can find them way cheaper.


Tell me more.

At what scale were you able to get a significant discount and how much?

Most people will be (full) fine-tuning on 8xH100 or 16xH100 for a few days at a time.


Roughly how much time did you spend translating the Torch code to JAX vs how much did you spend on PyTorch XLA?


It took roughly 2-3 weeks to translate Torch to JAX, but I had past experience writing JAX from my time at Google.

We spent nearly 4 weeks getting PyTorch XLA working on TPU. Hope that answers your question!


Anyone want to comment on this versus the fine-tuning speedups for Llama 3.1 with Unsloth?


Unsloth is great! They focus on single-GPU and LoRA fine-tuning on NVIDIA GPUs. We are initially trying to target multi-node, multi-TPU, full-precision training use cases.

That said, in terms of single-GPU speed, we believe we would be behind but not too far off, thanks to JAX+TPU's more performant stack. Additionally, we can do larger-scale multi-node training on TPUs.

There are still more optimizations we need to do for Llama 3.1, such as adding Pallas memory-efficient attention kernels.
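For anyone unfamiliar with Pallas: it's JAX's way of writing custom TPU/GPU kernels in Python. A real memory-efficient attention kernel is far more involved; this is just the minimal shape of a Pallas kernel, to give a sense of what that work entails (pass interpret=True to pallas_call if you want to test it on CPU):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def add_kernel(x_ref, y_ref, o_ref):
        # Refs point at blocks of device memory; the body describes one block's work.
        o_ref[...] = x_ref[...] + y_ref[...]

    @jax.jit
    def add(x, y):
        return pl.pallas_call(
            add_kernel,
            out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        )(x, y)

    print(add(jnp.ones(1024), jnp.ones(1024))[:4])  # [2. 2. 2. 2.]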


Where in the codebase is the logic specific to TPU vs. CUDA?


The codebase heavily uses PyTorch XLA libraries (torch_xla.*), which are specific to TPU. Key TPU-specific elements include XLA device initialization, SPMD execution mode, TPU-specific data loading, and mesh-based model partitioning.

[0] https://github.com/felafax/felafax/blob/main/llama3_pytorch_...

[1] https://pytorch.org/xla/master/
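A condensed sketch of what that TPU-specific glue looks like (API names from torch_xla as I recall them; double-check against the linked docs, since they move between releases):

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs
    from torch_xla.distributed.spmd import Mesh

    xr.use_spmd()                 # SPMD execution mode
    device = xm.xla_device()      # XLA device init (the TPU)

    # Mesh-based partitioning: lay the TPU cores out along (data, model) axes.
    n = xr.global_runtime_device_count()
    mesh = Mesh(np.arange(n), (n, 1), ("data", "model"))

    x = torch.randn(8, 4096).to(device)
    xs.mark_sharding(x, mesh, ("data", None))   # shard the batch dim across cores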


I'm surprised it's only 30% cheaper vs Nvidia. How come? This seems to indicate that the Nvidia premium isn't as high as everybody makes it out to be.


30% is a conservative estimate (to be precise, we went with this benchmark: https://github.com/GoogleCloudPlatform/vertex-ai-samples/blo...). However, the actual difference we observe ranges from 30-70%.

Also, calculating GPU costs is getting quite nuanced, with a wide range of prices (https://cloud-gpus.com/) and other variables that make it harder to do an apples-to-apples comparison.


Did you try running this task (finetuning Llama) on Nvidia GPUs? If yes, can you provide details (which cloud instance and time)?

I’m curious about your reported 30-70% speedup.


I think you slightly misunderstood, and I wasn't clear enough—sorry! It's not a 30-70% speedup; it's 30-70% more cost-efficient. This is mainly due to non-NVIDIA chipsets (e.g., Google TPU) being cheaper, with some additional efficiency gains from JAX being more closely integrated with the XLA architecture.

No, we haven't run our JAX + XLA on NVIDIA chipsets yet. I'm not sure if NVIDIA has good XLA backend support.


Then how did you compute the 30-70% cost efficiency numbers compared to Nvidia if you haven’t run this Llama finetuning task on Nvidia GPUs?


Check out this benchmark where they did an analysis: https://github.com/GoogleCloudPlatform/vertex-ai-samples/blo....

At the bottom, it shows the calculations around the 30% cost efficiency of TPU vs GPU.

Our range of 30-70% is based on some numbers we collected from running fine-tuning runs on TPU and comparing them to similar runs on NVIDIA (though not using our code but other OSS libraries).


It would be a lot more convincing if you actually ran it yourself and did a proper apples to apples comparison, especially considering that’s the whole idea behind your project.


It's also comparing prices on Google Cloud, which has its own markup and is a lot more expensive than, say, Runpod. Runpod is $1.64/hr for the A100 on secure cloud, while the A100 on Google is $4.44/hr. A lot more expensive... yeah. So in that context, a 30% price beat is actually a huge loss overall.


who trains on a100 at this point lol


It's the chosen point of comparison in the linked paper.


Totally agree, thanks for the feedback! This is one of the TODOs on our radar.


Nvidia's margin is like 70%. Using Google TPUs is certainly going to erase some of that.


They sell cards and they are selling out


An interesting thread with speculation about how to eventually do this on local TPUs with llama.cpp and GGUF infrastructure: https://www.reddit.com/r/LocalLLaMA/comments/12o96hf/has_any...


That's not happening. The Coral Edge TPUs are ancient, slow, and don't have enough memory to be meaningful, and somehow they still manage to be relatively expensive even secondhand.

They have some good uses, but LLMs ain't it.


Are those the TPUs Google sells to consumers? I've been thinking of buying one & hooking it up to a Pi just to play around with LLMs or Stable Diffusion. But I didn't realize they were slower/worse than other options.


The Coral TPUs have not been updated for several years. They were last updated long before the current LLM craze. They are good for simple things like object detection in photos.

They have almost nothing in common with Cloud TPUs.


Ahh, the Reddit thread is referring to Edge TPU devices; will check it out.

Google also has Cloud TPUs, which are their server-side accelerators, and this is what we are initially trying to build for!


For the 99% case, flash is enough. Period.


You might want to change the Road Runner logo because it's definitely copyrighted.


Haha, yeah, good point. I'll remove it.



