Any recommendations about how to use a PyTorch-trained model for inference? Is it best to load it up with PyTorch directly, or convert to ONNX and use ONNX Runtime [1] instead? The ONNX route seems to be required at least if you want to run the model through TensorRT. I appreciate this is a very general question.
I think using TRTorch [1] can be a quick way to generate inference models from PyTorch that are both easy to use and fast.
It compiles your model ahead of time using TensorRT, and lets you load the compiled model with torch.jit.load("your_trtorch_model.ts") in your application.
Once compiled, you no longer need to keep your model's code in the application (just as with regular JIT models).
The inference time is on par with TensorRT and it does the optimizations for you as well.
You can also quantize your model to FP16 or INT8 using PTQ, which should give you an additional inference speed-up.
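For reference, the basic ahead-of-time flow looks roughly like this. This is a sketch against the old TRTorch 0.x dict-style compile spec (the project has since become torch_tensorrt and the spec format changed), so treat the exact keys, shapes, and the toy model as assumptions:

import torch
import torch.nn as nn
import trtorch  # now published as torch_tensorrt; API names differ by version

# toy model standing in for your real network
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
    def forward(self, x):
        return torch.relu(self.conv(x))

scripted = torch.jit.script(TinyNet().eval().cuda())

# compile ahead of time with TensorRT; FP16 here (INT8 PTQ would also need a calibrator)
trt_module = trtorch.compile(scripted, {
    "input_shapes": [[1, 3, 224, 224]],
    "op_precision": torch.half,
})
torch.jit.save(trt_module, "your_trtorch_model.ts")

# in the serving application: no model source code required
deployed = torch.jit.load("your_trtorch_model.ts")
out = deployed(torch.randn(1, 3, 224, 224, device="cuda").half())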
There's another level of speed you can unlock by combining this with CUDA graphs: https://pytorch.org/docs/master/notes/cuda.html#cuda-graphs. I got (I kid you not) a 20x speed-up on batch size = 1 inference by first using TensorRT to fuse kernels and then "graphing" the model. And even for larger batch sizes it's basically free perf gains.
The model I got 20x on is very simple - just a couple of convs and ReLUs - it's for edge detection on a pseudo-embedded platform (Jetson). The wins from CUDA graphs come from two things: complete elimination of individual kernel launch times, and complete elimination of allocations for intermediate tensors, both of which dominate runtime when the kernels themselves are small (e.g. batch size = 1).
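For anyone curious, the inference-only "graphing" part looks roughly like this (a plain toy conv/relu model stands in here and the TensorRT step is left out; assumes PyTorch 1.10+ and a CUDA device):

import torch

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1), torch.nn.ReLU(),
    torch.nn.Conv2d(16, 1, 3, padding=1),
).eval().cuda()

static_input = torch.randn(1, 3, 224, 224, device="cuda")

# warm up on a side stream so lazy init / autotuning doesn't end up in the capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# capture: every kernel launch and intermediate allocation is recorded once
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g), torch.no_grad():
    static_output = model(static_input)

# replay: copy new data into the static input buffer and relaunch the whole graph
static_input.copy_(torch.randn(1, 3, 224, 224, device="cuda"))
g.replay()
print(static_output.shape)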
That is so cool! May I ask at which resolution you got those results?
We managed to get up to 10x for very low resolutions (160) on a ResNet-101, but it usually plateaus at a 1.7-1.9x speed-up for high resolutions (above 896x896).
Although using INT8 gives even higher speed-ups (~3.6x for 896x896 input), for some tasks it degrades accuracy too much.
As someone who's been vaguely interested in PyTorch inference optimization but has never had a clear jumping-in point, thank you for this comment! It's nice to see a clear two-sentence explanation that actually makes sense to me - it makes me really want to try out TRTorch and TensorRT!
As far as I know, the ONNX format won't give you a performance boost on its own. However, there are graph optimizers in ONNX Runtime that will speed up your inference.
But if you are using Nvidia hardware, TensorRT should give you the best possible performance, especially if you lower the precision level. Don't forget to simplify your ONNX model before converting it to TensorRT, though: https://github.com/daquexian/onnx-simplifier
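A minimal export-and-simplify sketch, assuming a toy model and placeholder file names:

import torch
import onnx
from onnxsim import simplify  # pip install onnx-simplifier

model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU()).eval()
dummy = torch.randn(1, 16)
torch.onnx.export(model, dummy, "model.onnx", opset_version=13)

onnx_model = onnx.load("model.onnx")
simplified, ok = simplify(onnx_model)
assert ok, "simplified model failed the checker"
onnx.save(simplified, "model_simplified.onnx")
# hand model_simplified.onnx to trtexec or the TensorRT Python API from here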
The main thing you want for server inference is automatic batching. It's a feature included in ONNX Runtime, TorchServe, NVIDIA Triton Inference Server, and Ray Serve.
Related to this question, can someone explain the design goal of torch.jit to me? Is it supposed to boost performance, or just provide a means to export models? I found my jitted code ran slower than interpreted PyTorch, and the latter, despite its asynchronous nature, spent most of its time waiting for the next GPU kernel to start.
Having got a working Torch model on CPU, what's the best path to actually making it run as fast as I feel it has the potential to?
It’s both. torch.jit started life as an optimizer. I think fusion of pointwise kernels on GPU - which we finally extended to CPU in this release - was one of the early wins via jit.
But at some point it became a model export format for production environments that can’t use CPython for performance reasons.
I’m surprised that you’re seeing worse performance with jit. It sometimes takes 20-ish iterations for the jit to “settle down” but I’d expect roughly equal performance at worst. If you can share a repro, I’d be happy to take a quick look if you file an issue on GitHub. (I’m @bertmaher there)
You know what I still don't understand? What's taking so long to warm up? I see that there are graph passes that run to do various things at the TS IR level, but I don't see any stats being collected (on shapes, say) that then inform further optimization.
There's a "profiling graph executor" that records shapes and then hands them off to a fusion compiler. The profiling executor re-specializes on every new shape it sees, but stops after 20 re-specializations.
We’re working on eliminating the dependence on shape specialization right now, since it’s kind of an unfortunate limitation for some workloads.
Oh also to answer your “as fast as possible” question: usually you’ll get the best performance by exporting your model to a perf-tuned runtime. We’ve seen really good results with TensorRT and (for transformers) FasterTransformer. I’ve also seen good results with ONNX runtime.
Staying within PyTorch, we recently added torch.jit.optimize_for_inference (I think it's in 1.10, though not entirely sure) that applies a bunch of standard optimizations to a model and often provides some nice wins.
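Roughly like this (a sketch assuming a PyTorch version that ships torch.jit.freeze and torch.jit.optimize_for_inference):

import torch

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8), torch.nn.ReLU()
).eval()

scripted = torch.jit.script(model)
frozen = torch.jit.freeze(scripted)                    # inline weights, drop training-only logic
optimized = torch.jit.optimize_for_inference(frozen)   # conv+bn folding and friends

with torch.no_grad():
    out = optimized(torch.randn(1, 3, 32, 32))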
In my experience, torch.jit shouldn't impact your performance positively or negatively, although I've only used it on CPU. As far as I know, it's mainly used for model export.
The nice thing about it, though, is that you can embed native Python code (compiled to TorchScript and run by the C++ runtime) in the model artifact. It's allowed us to write almost all of the serving logic of our models right next to the model code itself, which gives a better overview than having the serving logic live in a separate repo.
The server we use on top of this can be pretty "dumb" and just funnel all inputs to the model; the embedded Python code decides what to do with them.
As for model speed-ups, maybe you should look into quantization? I also find there's usually a lot of low-hanging fruit if you go over the code and rewrite it with faster ops that are mathematically equivalent but allocate less memory or do fewer operations.
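For example, post-training dynamic quantization is about as low-effort as it gets for CPU inference (toy model as a stand-in):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10)).eval()
# weights of the Linear layers are stored as int8; activations are quantized on the fly
qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
out = qmodel(torch.randn(1, 128))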
It makes it possible to lift your model out of Python while handling programming constructs like loops and if statements (see torch.jit.script).
It also makes it possible to sidestep the GIL and remove the overhead of launching kernels from Python, which only really makes a noticeable difference for models that queue up a lot of small operations on the GPU. (LSTMs are an example of where this makes a difference: https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscrip...)
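A tiny illustration of the scripting side - data-dependent control flow that tracing can't capture is preserved (toy function, nothing more):

import torch

@torch.jit.script
def clip_steps(x: torch.Tensor, n: int) -> torch.Tensor:
    # the loop and the if-statement survive into the TorchScript program
    for _ in range(n):
        if x.sum() > 0:
            x = x * 0.5
        else:
            x = x + 1.0
    return x

print(clip_steps(torch.randn(3), 4))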
PyTorch's serialized models are just Python pickle files, so to load them you need the original classes that were used to build the model. Converting to ONNX gets rid of those dependencies.
Your mileage will certainly vary, but I was able to eke out a lot more inference performance by exporting to ONNX, using a high-speed serving framework from that ecosystem, and applying some of the computation-graph optimizations that community tools can perform on the ONNX version - versus serving from PyTorch directly.
We were doing millions of inferences with a specific target of a couple thousand per second, so a specific case for sure, but that's my two cents.
- PyTorch-BigGraph specifically uses torch.distributed with Gloo (with an MPI backend).
So here's the question - if you're a two-person startup that wants to do PyTorch distributed training using one of the cloud-managed EKS/AKS/GKE services, what should you use?
The PyTorch Lightning people have come up with grid.ai. I personally have obtained good results using PyTorch Lightning plus Slurm on HPC machines. If I were a startup, I would probably try to build my own small HPC cluster, since that is far more cost-effective than renting.
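For the Lightning side, the setup is mostly a Trainer configuration that mirrors the sbatch script, since Lightning picks up the Slurm environment automatically. A sketch, assuming an HPC node with GPUs and a Lightning version that takes the accelerator/devices/num_nodes arguments:

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,        # GPUs per node, matching --ntasks-per-node in the sbatch script
    num_nodes=2,      # matching --nodes
    strategy="ddp",
)
# trainer.fit(model, datamodule)  # model/datamodule are whatever your project defines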
Most early-stage startups get tens of thousands of dollars of free AWS credits: https://aws.amazon.com/activate/
$100K if you're part of a university accelerator.
It is far, far more efficient (in terms of time-to-market) to rent and build on top of managed services.
Kubernetes is where the wider ecosystem is. I don't like it... but it is what it is.
So Grid.ai is something like AWS SageMaker. I wanted to figure out what someone can use on a ready-made Kubernetes cluster.
During my PhD I burned through ~10^5 compute hours on 4-8x V100 GPU, 128-core, 1TB+ HPC nodes. Your $10k of AWS credits run out really quickly compared to that. And as a startup you are far better off spending some money on a few beefy machines before renting them.
I would agree with you if we were talking about a web service in general, but a GPU server that is only used in-house to run some common PyTorch training jobs is very simple to buy, install, and run. Those cloud credits are going to get eaten fast, since it's a few dollars per hour - assuming you can get them at all (those programs are only available if you get your financing from a participating incubator or investor). But yeah, if you do get $100k for free from someone, you should of course use that first :)
Check out Determined (https://github.com/determined-ai/determined). It supports deploying onto k8s and handles running Horovod (and soon other distributed-training backends), with most of the complexity abstracted behind a few configuration values. It also gives you experiment tracking, hyperparameter search (ASHA), scheduling, profiling, etc.
I've largely moved to JAX, but it looks like PyTorch is maybe moving in that direction with torch.fx? The docs on it aren't really clear, though. Has anyone used it?
IMO, FX is more of a toolkit for writing transforms over your FX modules than "moving in Jax's direction" (although there are certainly some similarities!)
It's not totally clear what "JAX's direction" means to you, but I'd consider its defining characteristics to be 1. composable transformations, and 2. a functional style of programming (related to those function transformations).
Stuff like vmap and grad and pmap and all the rest have been a huge boon in simplifying some of my work, so I'm glad to see it's expanding into pytorch!
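A toy example of what that composability looks like in practice (names and shapes are made up):

import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((x @ w) ** 2)

# grad gives the gradient w.r.t. w; vmap maps it over a batch of examples
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))
w = jnp.ones((3, 2))
xs = jnp.ones((5, 3))
print(per_example_grads(w, xs).shape)  # (5, 3, 2): one gradient per example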
By “backend” I mean the compiler logic backing traced tensors in JAX.
torch.fx seems different in that it’s primarily aimed at being a platform for users, rather than JAX which (as far as I know) hides the logic behind @jit annotations.
So (to me) it seems that they are similar backends/library primitives with different front-ends. There doesn’t seem to be a difference in representational power, since both hit a graph representation. The main exception I could see would be something like timers, which would perhaps require a graph-mode equivalent for JAX.
With this kind of stuff, I think the devil is in the details, but the principle is similar, yes. Specifically, the level of abstraction at which you're tracing, as well as what types of programs you can express/not express.
For example, FX is extremely unopinionated about what it can trace, and the trace itself is extremely customizable. If a sub-function/module has control flow (i.e. is untraceable), it's easy to mark it as a "leaf" in FX's tracer, while that concept doesn't really make sense in JAX's tracing system.
Another difference is that JAX traces out into its own IR called a jaxpr, while FX is explicitly a Python => Python translation. This has some upsides and some downsides - for example, you can insert arbitrary Python functions into your FX graph (breakpoints, print statements, etc.), while jaxprs don't allow that.
Is this a good thing? Well, if your main goal is to lower to XLA, definitely not lol. But for FX it works quite well.
TL;DR: The general principles of doing graph capture are similar, but the details matter, and the details end up being quite different.
Can someone please answer this? I'm so curious. The only real life use case that I've seen mentioned is "programmatically generating models, for example from a config file".
But due to Python's dynamic nature, this is already possible. AllenNLP is a great example of that.
To give another take: Python's dynamic nature allows you to easily change class properties, methods, and so on. What is difficult in Python, however, is actually modifying the executed code.
For example, say you had

import torch.nn.functional as F

def f(x):
    return F.relu(x)
Now, you want to change all the activations in your network from relu to gelu. This... is not so easy to do generically. You might be able to do it for your own code, but what if you're importing a model from torchvision?
With FX, though, you can simply trace out the graph, substitute F.relu with F.gelu, and be done in <10 lines of code!
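Something like this sketch (a tiny made-up module that calls F.relu; module-based activations like nn.ReLU would need the analogous call_module check):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)
    def forward(self, x):
        return F.relu(self.lin(x))

traced = fx.symbolic_trace(Net())
for node in traced.graph.nodes:
    if node.op == "call_function" and node.target is F.relu:
        node.target = F.gelu   # swap the op at the call site
traced.recompile()

print(traced.code)             # the generated forward now calls gelu
out = traced(torch.randn(2, 8))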
Essentially, it gives you the freedom to perform transformations on your code (although it places limitations on what your code can contain, like no control flow).
Another example is the recent support from torchvision for extracting intermediate feature activations (https://github.com/pytorch/vision/releases/tag/v0.11.0). Like, sure, it was probably possible to refactor all of their code to enable users to specify extracting an intermediate feature, but it's much cleaner to do with FX.
Another example is a feature that was just added to torchvision which can take a classification model and extract the backbone for generating embeddings.
Thanks for this reply! Got some follow up questions if you don't mind being bothered ...
> It makes it possible to automate optimizing python models, adding things like conv and batch norm fusion for inference.
By "optimize", do you mean "reduce computational load", or "use Adam/SGD/whatever to minimize a loss function"? What is "conv and batch norm fusion"? How does FX help with any of this?
> It also allows you to plug in other ops to for example make quantization or profiling easier.
I can indeed see how it could make profiling easier. I'd love to get pointers/links as to quantization methods that would necessitate adding new ops.
> Another example is a feature that was just added to torchvision which can take a classification model and extract the backbone for generating embeddings.
Hasn't it always been possible to extract a certain set of weights from some `nn.Module`?
> What is "conv and batch norm fusion"? How does FX help with any of this?
Essentially, during inference, batch norm is simply a per-channel multiply and add. If it occurs right after a convolution, you can fold (i.e. "fuse") the batch norm into the convolution by modifying the convolution's weights and bias. See https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.... for more details.
What the FX pass ends up looking like is:
1. Look for a convolution followed by a batch norm (where the convolution output is not used anywhere else).
2. Modify the convolution's weights to reflect the batch norm.
3. Change all users of the batch norm's outputs to use the convolution's output.
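For step 2, the arithmetic being folded in is just a per-channel rescale and shift. Here's a small standalone numeric check of that identity (not the actual FX pass, which lives in the linked tutorial):

import torch
import torch.nn as nn

def fuse_conv_bn(conv, bn):
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    with torch.no_grad():
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        fused.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)
    return fused

conv, bn = nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)
with torch.no_grad():   # give the BN non-trivial statistics so the check means something
    bn.running_mean.normal_()
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.normal_()
    bn.bias.normal_()
conv.eval(), bn.eval()

x = torch.randn(1, 3, 16, 16)
print(torch.allclose(bn(conv(x)), fuse_conv_bn(conv, bn)(x), atol=1e-5))  # True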
> Hasn't it always been possible to extract a certain set of weights from some `nn.Module`?
IIRC, what torchvision allows you to do now is to get a new model that simply computes the embedding.
For example, something like this (roughly the shape of the torchvision API):

from torchvision.models import resnet18
from torchvision.models.feature_extraction import create_feature_extractor

model = resnet18()
# Returns a new model that takes the same inputs and returns the output at layer3
model_backbone = create_feature_extractor(model, return_nodes=['layer3'])
Before, you'd need to manually extract out the modules that did what you wanted, and things could be tricky depending on which activation you wanted exactly. For example, if you wanted it to output one activation from inside each residual block, how would you do it?
The FX-based API allows you to simply specify what outputs you want, and it'll create a new model that provides those outputs.
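Roughly like this, assuming torchvision >= 0.11 (the node names here are just examples; get_graph_node_names lists the real ones):

import torch
from torchvision.models import resnet18
from torchvision.models.feature_extraction import (
    create_feature_extractor, get_graph_node_names)

model = resnet18()
train_nodes, eval_nodes = get_graph_node_names(model)  # inspect what you can ask for

# one output per residual stage; the dict maps node name -> key in the output dict
extractor = create_feature_extractor(
    model, return_nodes={"layer1": "feat1", "layer2": "feat2",
                         "layer3": "feat3", "layer4": "feat4"})
feats = extractor(torch.randn(1, 3, 224, 224))
print({k: v.shape for k, v in feats.items()})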
This looks great! Excited to see more pretrained models made easily accessible through Torchvision, and nn.Module parameterization seems like a really intuitive and neat way of tackling the 'parameterize the parameters' problem effectively. Kudos to the team at PyTorch, I can't wait to start playing with the new features!
@vpj The only drawback I see is that much of the implementation is abstracted by your helper libraries. Not everyone wants to add an extra dependency layer. Otherwise the walkthroughs are super helpful.
Since the backbone of the implementation is packed away in the library import, I felt it didn't quite show the code & variable interaction well enough.
Don't get me wrong - it is useful and concise, like you mentioned. But your target audience is beginners and adopters, and this makes it no different from another framework such as fastai (I have major gripes with them; it's a much more bottled-in experience).
To be true to the walkthroughs, please consider writing plain helper functions rather than using your framework. Admittedly it may not be as beautiful, but your users will appreciate the extra mile you go to make things transparent and similar to the PyTorch docs.
Your notebooks are very useful, thank you! May I suggest making their background white, or the text colors less saturated? The keywords, function names, etc. are very difficult to read - I have to paste the code into another editor.
fast.ai is basically PyTorch plus loads of utility functions. By following the course/book, one learns to some extent what fastai does and how it uses PyTorch for practical work.
I went through the first version of the fast.ai course (when it was built on Keras - torch? tensorflow?) and forgot most of it because I never did anything with deep learning afterwards. Then I did the course again, this time the version that uses the fastai v2 library.
I really liked the first version of the course because I felt Jeremy did an awesome job of balancing understanding the guts of a DL library with getting stuff done. It was a tough course, but I felt like I really understood things.
I felt like the version built on the fastai v2 library went too far into "here are some commands you can use in the fastai v2 library to do this sexy thing with deep learning." I completed that course and really felt it should have been titled "A Course on the fastai v2 Library".
I recently purchased "Deep Learning with PyTorch" by Eli Stevens. I've been working through it and feel like it explains things a lot more thoroughly. I haven't finished the book yet, but I do like it so far.
I felt similarly about the fast.ai course. I'm doing the deeplearning.ai course (Andrew Ng) and I really like it. Warning though: it is bottom-up, so you start by implementing neural nets with NumPy to understand the internals, and then it gets into TensorFlow. Just another resource for people looking to learn deep learning rather than PyTorch in particular.
I second this, but once you are comfortable with fast.ai it's a good idea to eventually reproduce some things in pure PyTorch, because it'll give you a deeper understanding of (and appreciation for) fast.ai as well.
One thing that annoyed me when I built a project with fastai was that the v1 and v2 APIs are totally different, and if you google or search Stack Overflow for help with something, you're more likely to stumble on answers for the v1 API than the v2 API. I also didn't find their documentation super helpful for more than the most basic things (though not all documentation can be as amazing as scikit-learn's).
It seems everyone wants deep learning to work on M1 chips. At this point, I am quite sure there is enough interest for major frameworks to consider supporting it.
Small addition: matrix multiplications (and other operations implemented through BLAS) do use the M1's AMX matrix co-processor (through the Apple Accelerate framework).
[1] https://github.com/microsoft/onnxruntime