Choice of torch.compile vs. Triton

Hi, I’m new to torch.compile, and the docs say:

torch.compile is the latest method to speed up your PyTorch code! torch.compile makes PyTorch code run faster by JIT-compiling PyTorch code into optimized kernels, all while requiring minimal code changes.

Does that mean that if I use torch.compile on models/functions, it gives kernel-fusion optimizations similar to Triton’s?

What’s the difference between torch.compile and Triton? Is either one sufficient for optimization, or can they be used orthogonally at the same time?

I’m new to all these kinds of optimizations and super curious; any suggestions or references would be greatly appreciated! Thanks!

On GPUs, torch.compile() applies various compiler optimizations, the most important being CUDA graphs and fusions. Fusions specifically are done by code-generating Triton kernels, so torch.compile() is actually a great starting point for writing Triton kernels: you can always take the generated code and try to write something faster yourself.
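
For a concrete picture, here's a minimal sketch (the function and shapes are just illustrative) of the kind of pointwise chain that Inductor will fuse into a single generated Triton kernel on GPU:

```python
import torch

def bias_gelu(x, bias):
    # Three pointwise ops (add, gelu, mul). In eager mode each one
    # launches its own kernel; under torch.compile, Inductor can fuse
    # them into one generated Triton kernel on GPU.
    return torch.nn.functional.gelu(x + bias) * 2.0

compiled = torch.compile(bias_gelu)

x = torch.randn(1024, 1024, device="cuda")
bias = torch.randn(1024, device="cuda")
out = compiled(x, bias)  # first call triggers JIT compilation
```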

From a recent talk I gave


@marksaroufim Thank you! Would it be possible to share the talk?

torch.compile() is actually a great starting point for writing Triton kernels: you can always take the generated code and try to write something faster yourself.

I don’t quite follow here: does that mean that torch.compile() can generate human-readable Triton kernel files that people can tweak by hand and plug back into the model definition? My usage of torch.compile has simply been wrapping a model with it during training.

Yes, that’s right. Here’s a link to the original talk: https://www.youtube.com/watch?v=LuhJEEJQgUM&t=1s

It’s not that common for end users to modify the Triton code, but for PyTorch compiler devs it’s quite useful.
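
If you do want to read the generated kernels, here's a sketch of one way to surface them (assuming PyTorch 2.1+; `TORCH_LOGS="output_code"` on the command line does the same thing):

```python
import torch

# Ask the compiler to log the generated (Triton) kernel source.
torch._logging.set_logs(output_code=True)

@torch.compile
def f(x):
    return torch.relu(x) + 1.0

f(torch.randn(1024, device="cuda"))
# The generated kernel source is printed to the log. Alternatively,
# setting TORCH_COMPILE_DEBUG=1 writes it out as files you can edit.
```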


@marksaroufim May I ask to what extent torch.compile optimizes the computation for users? Given how powerful torch.compile is, should we always use it first before writing Triton/CUDA kernels for extra efficiency?

I’ve found many tutorials (Triton, torch.compile, CUDA, etc.) about making models more efficient, but I’m just not sure whether they are interchangeable.

The main optimizations torch.compile will perform are:

  1. Fusions
  2. CUDA graphs
  3. Template matching for SDPA (i.e., flash attention)

It’s also an active project, so I’m sure there’s more in flux. I can’t say what you should always do, but personally, yes: I always reach for torch.compile first before attempting to write a custom kernel.
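
As a rough sketch touching the second and third items above (mode="reduce-overhead" is what turns on CUDA graphs, and the shapes here are just illustrative):

```python
import torch
import torch.nn.functional as F

def attention(q, k, v):
    # Written with the SDPA op so the compiler can pattern-match it
    # to an efficient implementation such as flash attention.
    return F.scaled_dot_product_attention(q, k, v)

# "reduce-overhead" enables CUDA graphs to cut kernel-launch overhead.
compiled_attn = torch.compile(attention, mode="reduce-overhead")

q = k = v = torch.randn(8, 16, 1024, 64, device="cuda", dtype=torch.float16)
out = compiled_attn(q, k, v)
```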


Thank you very much!