What is the current/future best practice for custom autograd functions (with Triton parts)?

Hi everyone,

I am unsure what the best practice is for custom autograd implementations, both in torch 2.4 and going forward. The approach I'm used to is wrapping custom functions in torch.autograd.Function, as described in the "Extending PyTorch" page of the PyTorch 2.4 documentation.
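For reference, a minimal sketch of the torch.autograd.Function pattern. The op itself (scaled_square, computing scale * x²) is hypothetical, and plain PyTorch ops stand in for the Triton kernel launches a real implementation would make.

```python
import torch


class ScaledSquare(torch.autograd.Function):
    """Hypothetical op y = scale * x**2 with a hand-written backward.

    In the Triton use case, forward/backward would launch Triton kernels;
    plain PyTorch ops stand in for them here so the sketch runs anywhere.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float) -> torch.Tensor:
        # Save what backward needs: the input tensor and the Python float.
        ctx.save_for_backward(x)
        ctx.scale = scale
        return scale * x * x

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        (x,) = ctx.saved_tensors
        # d/dx (scale * x^2) = 2 * scale * x; `scale` is a float, so its grad is None.
        return grad_out * 2.0 * ctx.scale * x, None


def scaled_square(x: torch.Tensor, scale: float) -> torch.Tensor:
    return ScaledSquare.apply(x, scale)


if __name__ == "__main__":
    x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    # gradcheck compares the hand-written backward against finite differences.
    print(torch.autograd.gradcheck(lambda t: scaled_square(t, 3.0), (x,)))
```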
But now there is also the torch.library approach with torch.library.register_autograd, described in the "Python Custom Operators" tutorial (PyTorch Tutorials 2.4.0+cu121).

My actual use case is custom Triton functions. These work nicely with torch.autograd.Function for normal use, but they don't work well with fake tensors, for example the ones used for FLOP counting in torch.utils.flop_counter.

Concrete questions:

  • What is the best practice for optimal torch.compile support?
  • If the custom op consists only of PyTorch code plus Triton code, does the choice matter?
  • If I want a fake tensor stub, or want to register that my op only works on CUDA, should I switch to torch.library?
  • If I use a torch.library custom_op, does that prevent torch.compile from fusing into the custom operator? (E.g., merging a reshape inside the custom_op into a preceding operation outside it.)
  • How do I declare torch.amp support with torch.library, i.e., get the same behavior that @torch.amp.custom_fwd provides for torch.autograd.Function?

Thanks!
