Hi everyone! Was wondering if there was any way to compile a non-forward method of an `nn.Module`, i.e. something similar to what `torch.jit.export` does. To the best of my knowledge, only the `forward` method gets compiled out-of-the-box, and I wasn't able to find any docs on the subject. Would greatly appreciate any help!
You should be able to compile other methods on your nn module by decorating them directly, e.g.:
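```python
import torch

class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)  # placeholder layer for illustration

    def forward(self, x):
        return self.linear(x)

    @torch.compile
    def my_method(self, x):
        # placeholder body for illustration
        return torch.relu(self.linear(x)) * 2

m = Mod()
x = torch.randn(4, 4)
# my_method will get just-in-time compiled on its first call
out = m.my_method(x)
```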
One thing to call out is: what are you trying to do? Normally, if you just do `m = torch.compile(m)`, this will effectively end up causing `m.__call__` to be compiled (which is a thin wrapper around `forward` that also executes nn.Module hooks). So if you end up calling your method from inside your `Module.forward`, then compiling the entire module will cause its `forward`, as well as your custom method, to get inlined into a single compiled region.
Thank you for your answer! I am working with several modules without a forward method, so compiling them separately seems to be the way to go.