Compiling a method other than forward

Hi everyone! I was wondering whether there is any way to compile a method of an nn.Module other than forward, i.e. something similar to what torch.jit.export does. As far as I know, only the forward method gets compiled out of the box, and I wasn’t able to find any docs on the subject. Would greatly appreciate any help!

You should be able to compile other methods on your nn.Module by decorating them directly with @torch.compile, e.g.:

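import torch

class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)  # illustrative layer

    def forward(self, x):
        return self.linear(x)

    # Decorating the method compiles it independently of forward
    @torch.compile
    def my_method(self, x):
        return torch.relu(self.linear(x))  # illustrative body

m = Mod()
x = torch.randn(4, 4)
# my_method will get just-in-time-compiled on its first call
out = m.my_method(x)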
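If you can't (or would rather not) edit the class definition, a similar effect can likely be had by compiling the bound method on one instance and assigning it back. This is a minimal sketch assuming PyTorch 2.x; the layer and method body are illustrative placeholders, and backend="eager" is used only so it runs without a C++ toolchain (drop it to get the default backend).

```python
import torch


class Mod(torch.nn.Module):
    """Hypothetical module; the layer and method body are placeholders."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

    def my_method(self, x):
        return torch.relu(self.linear(x))


m = Mod()
x = torch.randn(4, 4)
eager_out = m.my_method(x)  # plain eager result, used as a reference

# Compile the bound method and shadow the class attribute on this instance only.
# backend="eager" captures the graph but skips code generation; omit it to use
# the default backend.
m.my_method = torch.compile(m.my_method, backend="eager")
compiled_out = m.my_method(x)

print(torch.allclose(eager_out, compiled_out))
```

Only this instance is affected; other instances of Mod keep calling the uncompiled method.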

One thing to ask first: what are you trying to do? Normally, if you just do m = torch.compile(m), this effectively causes m.__call__ to be compiled (a thin wrapper around forward that also runs nn.Module hooks). So if your custom method is called from inside your Module.forward, compiling the entire module will inline both forward and your custom method into a single compiled region.
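A minimal sketch of that case, assuming PyTorch 2.x: forward delegates to a helper method, and only the module itself is passed to torch.compile. The class, layer and helper body are hypothetical, and backend="eager" is there only so the sketch runs without a C++ toolchain.

```python
import torch


class Mod(torch.nn.Module):
    """Hypothetical module whose forward delegates to a custom helper method."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def my_method(self, x):
        # Not decorated: it is traced because forward calls it.
        return torch.relu(self.linear(x))

    def forward(self, x):
        return self.my_method(x) + 1


m = Mod()
x = torch.randn(4, 4)
eager_out = m(x)

# Compiling the whole module captures forward together with the helper it
# calls (assuming no graph breaks in between).
compiled_m = torch.compile(m, backend="eager")
compiled_out = compiled_m(x)

print(torch.allclose(eager_out, compiled_out))
```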


Thank you for your answer! I am working with several modules that have no forward method, so compiling their methods separately seems to be the way to go.
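For a module with no forward, torch.compile(m) has nothing useful to wrap, so each entry-point method has to be compiled on its own. The sketch below assumes PyTorch 2.x; Codec, its encode/decode methods and the compile_methods helper are all hypothetical names for illustration, and backend="eager" is used only so it runs without a C++ toolchain.

```python
import torch


class Codec(torch.nn.Module):
    """Hypothetical module that exposes encode/decode but defines no forward."""

    def __init__(self):
        super().__init__()
        self.enc = torch.nn.Linear(8, 2)
        self.dec = torch.nn.Linear(2, 8)

    def encode(self, x):
        return torch.tanh(self.enc(x))

    def decode(self, z):
        return self.dec(z)


def compile_methods(module, *names, **compile_kwargs):
    """Hypothetical helper: replace each named method on this instance with a
    torch.compile-d version of its bound method."""
    for name in names:
        setattr(module, name, torch.compile(getattr(module, name), **compile_kwargs))
    return module


codec = Codec()
x = torch.randn(3, 8)
eager_recon = codec.decode(codec.encode(x))  # eager reference

compile_methods(codec, "encode", "decode", backend="eager")
recon = codec.decode(codec.encode(x))  # each method compiles on first call

print(torch.allclose(eager_recon, recon))
```

Each method gets its own compiled region, so nothing is fused across the encode/decode boundary; if that matters, a small wrapper method that calls both could be compiled instead.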