You've got to give it to the PyTorch team, they're really great at bringing complex optimization schemes (mixed precision, torch.compile, etc.) down to a simple-to-use API. I'm glad I moved from TF/Keras to PyTorch around 2018-2019 and never looked back. I'm eager to try this as well.
I've seen and ignored a lot of "pytorch good, tensorflow bad" takes in my time, but this is so egregiously wrong I can't help but chime in. Facilitating graph-level optimizations has been one of the most central tenets of tensorflow's design philosophy since its inception. The XLA compiler was designed in close collaboration with the tensorflow team and was available in the tensorflow API as far back as 2017. It's not an exaggeration to say that pytorch is 5+ years behind on this front. Before anyone invokes the words "pythonic" or "ergonomic", I'd like to note that the tensorflow 2 API for compilation is nearly identical to torch.compile.
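For anyone who hasn't compared the two, here's a minimal side-by-side sketch of that similarity (toy function, nothing from the thread; the only API features assumed are tf.function's jit_compile flag and the decorator form of torch.compile):

```python
import tensorflow as tf
import torch

# TensorFlow 2: trace to a graph and JIT-compile it with XLA.
@tf.function(jit_compile=True)
def tf_step(x):
    return tf.sin(x) ** 2

# PyTorch 2: the decorator form of torch.compile plays the same role.
@torch.compile
def torch_step(x):
    return torch.sin(x) ** 2

print(tf_step(tf.constant([0.5, 1.0])))
print(torch_step(torch.tensor([0.5, 1.0])))
```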
11. notice that there's a unicode rendering error ("&#x27;" shown instead of an apostrophe) on the kernel_initializer and bias_initializer default arguments in the documentation, and wonder why on earth, for such a high-level API, one would want to expose lora_rank as a first-class construct. Also, 3 out of the 5 links in the "Used in the guide" section point to TF1-to-TF2 migration articles; TF2 was released 5 years ago.
To add onto this, I feel like one of the hard things about TF is that there are at least 3 ways to do everything, because they have supported multiple APIs and migrated to eager. So if you find an example or an open source project, it might not be for the flavor of tensorflow that your codebase is in.
I feel like that with every single Google API doc. If there's a variable called x, the documentation will be "variable to store x". And you need to create/supply 5 different resources before you can create an x, but these will each require 5 further things to be figured out before you can create one of them.
Re 6: the TF/Keras team motivates random people to write long tutorials by featuring them on the official site and including their tutorials in the official guides. I have seen a lot of subpar devs/AI people write subpar tutorials and brag on twitter about how their tutorials are included on the official Keras site.
Honestly, this example holds true for roughly half of the Python ecosystem; and you can square the level of frustration if it's also anything coming from Google.
(This pattern is relatively easy to understand: smart people creating something get their gratification from the creation process, not writing tedious documentation; and this is systemically embedded for people at Google, who are probably directly incentivised in a similar way.)
Tensorflow works really well in theory. In practice a lot less so. I saw someone spend months fighting Tensorflow to convert a production model from CPU to GPU inference with any sort of efficiency. Tons of issues due to bugs across versions, deprecations of features across versions, the graph optimizer shuffling data back to the CPU for no decent reason, etc. The person had no idea what was happening or why most of the time due to how black box Tensorflow was. This was a very senior ML engineer with a lot of Tensorflow experience.
Does tensorflow have a future? I doubt it. I don't think Google is really investing many resources into it (beyond the necessary maintenance to support whatever production models still depend on it). The cost of migrating from old TF to new TF was really large; half the projects that depend on TF that I try to use just break out of the box (only 1/4 of torch projects I try fail that way).
From what I can tell Google is moving in a direction that doesn't require tensorflow, and I don't see it gaining significant adoption outside Google, so it seems most likely we will simply see it deprecated in about 10 years. It's best to see it as a transitional technology that Jeff Dean created to spur ML development internally and that was mistakenly open sourced; now Jeff's reports typically use Jax or other systems.
I think tensorflow-datasets and tensorflow-serving are great, but for model development I think most people use JAX and then export it to a tensorflow SavedModel with Orbax.
But IIUC Jax also leverages XLA, and for the purpose of this discussion the frontend matters only inasmuch as people feel productive in using it, whether that's TF or Jax.
> Facilitating graph-level optimizations has been one of the most central tenets of tensorflow's design philosophy since its inception.
Agreed of course but it's not like they came up with this approach from scratch. They seem to have just picked it up from Theano (now Aesara/PyTensor).
Tensorflow is a lot like IBM -- it deserves praise not because it's great in its current state, but for its contributions towards advancing the broader technological front to where it is today. Tensorflow walked so JAX could run, so to speak. Frankly, I don't really draw much of a distinction between the two frameworks since I really just use them as lightweight XLA wrappers.
Tensorflow started out as anything but lightweight. In my opinion it takes the cake for kludgiest framework I've ever worked with. So verbose, so little effort put into ergonomics. Even eager mode is not really valuable unless you're working on a legacy project.
+1. As someone who has tried to migrate multiple tf.functions to torch.compile, tensorflow's edge here is not small. torch.compile is still highly experimental. Don't believe me? Just go look at the GitHub issues where torch maintainers try to figure out why torch.compile makes code very suboptimal in a lot of cases, or produces incomprehensible errors.
Hi! I'm Mark from the PyTorch team at Meta and work on torchao. If you have any questions about the library or really anything at all about performance, don't hesitate to ask!
A minor nitpick on the copy (and even then, it might just be me): I find "97% speedup" and "50% speedup" really hard to parse, whereas a "30x speedup" or a "97% reduction of time taken" immediately tells me what is being achieved!
Great results once I get my head around them, though!
That's why it's confusing: "2x speedup" would clearly indicate 200% of the current speed, so with "97% speedup" it's unclear whether it's a multiple (presumably not, because that would be a slowdown), a reduction in time (which was my assumption), or an increase in speed (something per unit of time).
I guess you are right and it's probably the latter, but obviously better language would have avoided any doubt.
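To make the ambiguity concrete (baseline time picked purely for illustration): if a run takes 100 s, a 97% increase in speed brings it to about 100 / 1.97 ≈ 51 s, while a 97% reduction in time taken brings it to 3 s, i.e. roughly a 33x speedup. The two readings differ by well over an order of magnitude, which is why the wording matters.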
Hi Mark, the library looks cool, excited to try it out. Coincidentally I am starting work on a project that is investigating a lot of Post training quantization methods. I read the blog and I am curious to understand what kind of overheads are involved in quantizing a layer?
There's a bunch of overhead associated with PTQ, but the TL;DR is that much of that overhead goes away when you're using `torch.compile()` and `torchao.autoquant()`.
Essentially the latency overhead comes from quantizing and dequantizing weights and activations. For large layers this overhead is small, because by quantizing your weights you reduce memory bandwidth pressure; but for small layers the overhead of potentially looking up a table, reading scaling factors, quantizing/dequantizing, and finally handling zero points might not be worth it.
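For anyone unfamiliar with where those costs come from, here's a rough sketch using plain per-tensor affine quantization (generic textbook math, not torchao's internals):

```python
import torch

def quantize_int8(x: torch.Tensor):
    # Per-tensor affine quantization: map the float range onto int8 [-128, 127].
    scale = (x.max() - x.min()) / 255.0
    zero_point = torch.round(-x.min() / scale) - 128
    q = torch.clamp(torch.round(x / scale) + zero_point, -128, 127).to(torch.int8)
    return q, scale, zero_point

def dequantize_int8(q, scale, zero_point):
    # Reading scale/zero_point and casting back is the per-layer overhead;
    # if this isn't fused with the matmul it costs extra kernel launches.
    return (q.to(torch.float32) - zero_point) * scale

w = torch.randn(256, 256)
q, scale, zp = quantize_int8(w)
w_hat = dequantize_int8(q, scale, zp)
print((w - w_hat).abs().max())  # small quantization error
```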
However, even if such overhead exists you can still quantize your model and get it to be smaller; the problem is it might not be faster. We solve the speed problem in 2 ways: `torch.compile()` will fuse operations like a dequant and a matmul into a single kernel, and `torchao.autoquant()` will do kernel-level profiling to see whether a layer is actually made faster by quantizing, and if not it skips quantizing that layer.
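A rough sketch of how those two pieces compose (the composition follows the torchao README at the time of writing; the toy model is just a stand-in, so double-check the exact call against current docs):

```python
import torch
import torchao

# Toy stand-in for a real model; any nn.Module works.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).cuda().eval()

# Compile first so dequant + matmul can be fused into one kernel, then let
# autoquant profile each layer and skip layers that quantization doesn't speed up.
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))

out = model(torch.randn(1, 4096, device="cuda"))
```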
First off, well done, this looks exciting. I haven't had a chance to interact with the library yet. Should torchao be seen as a dev-friendly quantization interface? I.e., if my team were working on new quantization techniques, does the API provide easy tooling for implementing and benchmarking new quantization algorithms? Or is this closer to a "toolbox of (generally) finished products"?
It's both! For this blog we decided to discuss our best end-user-facing numbers to keep things simple. We briefly hint at our contributor guide here (https://github.com/pytorch/ao/issues/391), which does a tour of the APIs we provide for developers implementing new algorithms.
But we have had developers of quantization algorithms such as HQQ or Autoround merge their code in to get composability and serialization for free. We view quantization algorithms as the top layer; going down you have quantized tensors, quant primitives like dequant/quant, and finally basic dtypes like uint1-7 and float3-8. Personally, the reason I spent so much time on AO was that I was hoping we could make it easier for people to express their quantization algorithms in easy-to-read PyTorch code, and if they must use custom kernels we also have some tutorials for how to integrate custom CUDA and Triton ops.
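For a sense of what that top layer looks like from the outside, here's a minimal sketch; it assumes the quantize_ / int8_weight_only entry points from the torchao quantization docs, so treat the exact names as something to verify:

```python
import torch
from torchao.quantization import quantize_, int8_weight_only

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda()

# The "algorithm" layer: one call swaps eligible weights for quantized tensor
# subclasses; the quant/dequant primitives and packed dtypes live underneath.
quantize_(model, int8_weight_only())
```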
Most of those discussions have been happening on #torchao on discord.gg/gpumode so if you need to chat back and forth feel free to reach out to the team there otherwise Github also works.
Most of our performance relies on leveraging torch.compile, which generates Triton kernels that run fast on CPU and GPU but not on MPS, since Triton does not support generating Metal kernels. So on MPS you lose the nice story of writing low-bit code in pure PyTorch and still getting it to run fast.
In these cases the only path forward we have is writing custom Metal kernels and plugging those in. That work is still ongoing and we'll hopefully have more to share soon.
This might not be the right place for this question but, as someone who has made a couple very modest mps backend contributions, I'm curious why not add metal support to triton (or a fork if openai won't allow it) rather than maintain a whole separate backend?
Mostly it comes down to what's fastest to develop: it's faster to write a few custom kernels than it is to develop a new compiler backend.
Granted, after the extra upfront effort, compilers are just such a significant UX boost that you are indeed making me question why I don't spend more time working on this myself lol
But that's waiting for Blackwell to be released so we get the hardware support. So the recommendation for now would be to use either fp8 training or int8 training.
There are different tradeoffs: spinning up a separate repo is what we call "out of core", vs having everything in PyTorch "in core".
Basically, PyTorch is a large library where CI takes a long time to run, which means merging code is hard, adding new dependencies is challenging, and there are stringent constraints on BC-breaking changes.
Instead, what torchao (and many other repos like torchtune, torchchat, torchtitan) did was move out of core. That helps keep the core PyTorch library leaner, with a smaller binary size, and it really lets the out-of-core team focus on optimizing for their own needs.
Unfortunately the argument for which is better changes over time. For example, torch.compile started out of core as a new repo called torchdynamo so it could move fast, but it eventually merged back in because everyone wanted to use it. torch.compile dev velocity is still quite fast, so now we have to tell people to use nightlies instead of official stable releases, which has led some people to ask me why we don't move torch.compile out of core.
My 2c is that the ecosystem will be much stronger and teams can move faster if they develop out of core, so that's the tradeoff we picked for torchao. We managed, for example, to merge a few custom CPP kernels like fp6 or Marlin that would have been challenging to motivate in core, since those are still quite experimental and need to stand the test of time.
Pardon my ignorance, but how do matrix operations on quantized data work? Is hardware support needed?
AFAIU int4 matrix multiplication is supported by cuda, but I'm not sure about other operations. The blog post mentioned fp6, and I don't think this is supported by cuda. Or maybe the data are upscaled to something common like fp16 before doing math?
It's a great question! Int4 is an easy one to understand. PyTorch supports int8 but not int4, so what you can do is "pack" 2 int4 values into a single int8 value. You still get speedups even without hardware support because you're sending less data to the GPU, and workloads like small-batch-size LLM inference are memory bandwidth bound, not compute bound. So indeed your intuition is correct: you pack the values, and before doing a matmul you "unpack" them back into int8 and then upcast to fp16 to do the matmul.
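A toy version of that pack/unpack trick in plain PyTorch (purely illustrative; torchao's actual packed layouts and kernels are more involved):

```python
import torch

def pack_int4(x: torch.Tensor) -> torch.Tensor:
    # x holds int4 values in [-8, 7]; store two of them per uint8 byte.
    # Assumes the last dimension has even length.
    u = (x + 8).to(torch.uint8)               # shift to unsigned [0, 15]
    return u[..., ::2] | (u[..., 1::2] << 4)  # low nibble | high nibble

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    lo = (packed & 0xF).to(torch.int16) - 8
    hi = (packed >> 4).to(torch.int16) - 8
    out = torch.stack([lo, hi], dim=-1).flatten(-2)
    return out.to(torch.float16)              # upcast before the fp16 matmul

w = torch.randint(-8, 8, (4, 8))              # toy int4 weights
assert torch.equal(unpack_int4(pack_int4(w)).to(w.dtype), w)
```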