Hi everyone,
I'm a bit unsure what the current best practice is in PyTorch 2.4 (and going forward) when it comes to custom autograd implementations. The syntax I'm used to is wrapping custom functions in torch.autograd.Function, as described here: Extending PyTorch — PyTorch 2.4 documentation.
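For reference, this is roughly the pattern I mean (names are made up, and in my real code the forward launches a Triton kernel; plain PyTorch stands in here):

```python
import torch

class TritonSquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # real code launches a Triton kernel; plain PyTorch stands in here
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return grad_out * 2 * x

y = TritonSquare.apply(torch.randn(8, requires_grad=True))
```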
But now there is also the torch.library implementation with torch.library.register_autograd, described here: Python Custom Operators — PyTorch Tutorials 2.4.0+cu121 documentation.
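This is my (possibly naive) understanding of the equivalent with the new API, going by that tutorial (op name and details are made up):

```python
import torch
from torch.library import custom_op

@custom_op("mylib::triton_square", mutates_args=())
def triton_square(x: torch.Tensor) -> torch.Tensor:
    # real implementation would launch the Triton kernel here
    return x * x

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

def backward(ctx, grad_out):
    x, = ctx.saved_tensors
    return grad_out * 2 * x

triton_square.register_autograd(backward, setup_context=setup_context)
```

I also see that custom_op takes a device_types argument, which I assume is what I'd use for the CUDA-only registration I ask about below.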
My actual use case is custom Triton functions, which work nicely with the old implementation for normal things, but, for example, don't work well with fake tensors, like the ones used for FLOP counting in torch.utils.flop_counter.
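Continuing the sketch above, my understanding is that this fake-tensor problem is what register_fake is meant to solve (again, this is just me guessing at the intended usage):

```python
# continues the mylib::triton_square sketch from above
@triton_square.register_fake
def _(x):
    # shape/dtype propagation only; no kernel is launched under FakeTensor
    return torch.empty_like(x)
```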
Concrete questions:
- What is the best practice for optimal torch.compile support?
- If the custom op is only PyTorch code + Triton code, does it matter?
- If I want a fake tensor stub, or want to register that my op only works on CUDA, should I switch to torch.library?
- If I use the torch.library custom_op, does that disable torch.compile fusion into the custom operator? (E.g. merging a reshape inside the custom_op into a previous operation outside the custom_op.)
- How do I annotate support for torch.amp with torch.library, i.e. the same thing that @torch.amp.custom_fwd does for torch.autograd.Function? (See the sketch after this list for the pattern I mean.)
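For context on that last question, this is the pattern I'm referring to with torch.autograd.Function (following the amp docs; I don't know the torch.library equivalent):

```python
import torch

class TritonSquareAmp(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    def forward(ctx, x):
        # runs in float32 even under autocast, thanks to cast_inputs
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return grad_out * 2 * x
```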
Thanks!