Hi everyone! Was wondering if there was any way to compile a non-forward method of an nn.Module, i.e. something similar to what torch.jit.export does. To the best of my knowledge, only the forward method gets compiled out-of-the-box, and I wasn’t able to find any docs on the subject. Would greatly appreciate any help!
You should be able to compile other methods on your nn module by decorating them directly, e.g.:
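import torch

class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

    # Decorating a non-forward method compiles it on its first call
    @torch.compile
    def my_method(self, x):
        return torch.relu(self.linear(x))

m = Mod()
x = torch.randn(4, 4)
# my_method will get just-in-time-compiled
out = m.my_method(x)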
One thing to call out: what are you trying to do? Normally, if you just do m = torch.compile(m), this effectively causes m.__call__ to be compiled (a thin wrapper around forward that also runs nn.Module hooks). So if you call your method from inside your Module.forward, compiling the entire module will compile forward and inline your custom method into the same single compiled region.
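A minimal sketch of that second pattern, assuming a hypothetical Block module whose forward calls a helper method _mix: compiling the whole module is enough, because the call into _mix is traced as part of forward. The backend="eager" argument is only there so the sketch runs without a C++ toolchain; drop it to use the default Inductor backend.

```python
import torch

class Block(torch.nn.Module):
    # Hypothetical module: forward delegates part of its work to a helper method.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def _mix(self, x):
        # Plain (undecorated) helper; it is traced as part of forward.
        return torch.tanh(self.proj(x)) + x

    def forward(self, x):
        return self._mix(x) * 2

block = Block()
# Compiling the module wraps its __call__ (forward plus any module hooks).
compiled_block = torch.compile(block, backend="eager")

x = torch.randn(3, 8)
out_eager = block(x)
out_compiled = compiled_block(x)
```

The compiled wrapper shares parameters with the original block, so both calls should produce the same result.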
Thank you for your answer! I am working with several modules without a forward method, so compiling them separately seems to be the way to go.
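For a module with no forward at all, a minimal sketch (assuming a hypothetical Codec module with encode and decode methods) is to decorate each method separately, so each one is compiled independently on its first call. Compiling the module object itself would not help here: calling it goes through nn.Module's default forward, which raises NotImplementedError. As above, backend="eager" is only used so the sketch runs without a C++ toolchain.

```python
import torch

class Codec(torch.nn.Module):
    # Hypothetical module with no forward(): work is split across two methods.
    def __init__(self):
        super().__init__()
        self.enc = torch.nn.Linear(16, 4)
        self.dec = torch.nn.Linear(4, 16)

    # Each decorated method is compiled independently on its first call.
    @torch.compile(backend="eager")
    def encode(self, x):
        return torch.relu(self.enc(x))

    @torch.compile(backend="eager")
    def decode(self, z):
        return torch.sigmoid(self.dec(z))

codec = Codec()
x = torch.randn(2, 16)
z = codec.encode(x)
recon = codec.decode(z)
```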
Suppose the class is already implemented in an imported library and you cannot modify it. In that case, you can apply torch.compile directly to the method of an instance.
In the example below, we call torch.compile directly on the decode method of a MyClass instance.
import torch
import my_lib

# Init and load the model
model = my_lib.MyClass()
# Call torch.compile on the bound decode method; assigning the result
# to model.decode replaces decode on this instance only
model.decode = torch.compile(model.decode)  # Add your compile args
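Since my_lib is not a real package, here is a self-contained sketch of the same instance-level patch, assuming a hypothetical ThirdPartyModel class standing in for my_lib.MyClass. Assigning the compiled bound method to an instance attribute shadows the class method for that object only; other instances and the class itself keep the original, uncompiled decode. backend="eager" is again only there to avoid needing a C++ toolchain.

```python
import torch

class ThirdPartyModel(torch.nn.Module):
    # Stand-in for a class from an external library that you cannot edit.
    def __init__(self):
        super().__init__()
        self.head = torch.nn.Linear(4, 4)

    def decode(self, z):
        return self.head(z).softmax(dim=-1)

model = ThirdPartyModel()
other = ThirdPartyModel()

# model.decode is a bound method; compiling it and storing the result as an
# instance attribute shadows the class method for this instance only.
model.decode = torch.compile(model.decode, backend="eager")

z = torch.randn(5, 4)
probs = model.decode(z)
```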