Forward function not being compiled by default

I have the following model classes:

import torch
import torch.nn as nn


class BaseClass(nn.Module):
    ...
    
class MyModel(BaseClass):
    def __init__(self):
        super().__init__()
        
        self.transformer = CustomTransformer(10, 10)
        self.linears = ...
        
    def forward(self, x):
        return self.transformer(x)

    def compute_predictions(self, x):
        # mask x
        return self(x)

class CustomTransformer(BaseClass):
    def __init__(self, d_model, nhead):
        super().__init__()
        self.layers = ...
    
    def forward(self, x):
        return self.layers(x)

Usually I call the compute_predictions method from outside; it performs some masking/indexing over the input x and then calls the forward function, which accounts for most of the runtime.
I am trying to benchmark the improvements that torch.compile provides (torch version '2.5.1+cu124'), using the following code:

import torch.utils.benchmark as benchmark

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

model = MyModel()
compiled_temporary_model = torch.compile(model)
with torch.autocast(device_type="cuda", dtype=torch.float16):  # float16 autocast (assuming a CUDA device)
    print(benchmark_torch_function_in_microseconds(model.compute_predictions, batch))
    print(benchmark_torch_function_in_microseconds(compiled_temporary_model.compute_predictions, batch))
    print(benchmark_torch_function_in_microseconds(torch.compile(compiled_temporary_model.compute_predictions), batch))

The output is as follows:

89699.9528631568
89938.04324418306
75973.31081827481

I would expect the second line to still show an improvement, because the forward method should be implicitly compiled. In addition, here is what I get if I benchmark the inner transformer layer, calling its forward method directly:

with torch.autocast(device_type="cuda", dtype=torch.float16):
    print(benchmark_torch_function_in_microseconds(model.transformer, input_))
    print(benchmark_torch_function_in_microseconds(compiled_temporary_model.transformer, input_))
    print(benchmark_torch_function_in_microseconds(compiled_temporary_model.transformer.forward, input_))
    print(benchmark_torch_function_in_microseconds(torch.compile(compiled_temporary_model.transformer), input_))

Output:

86728.8438603282
87124.97558444738
87061.39624118805
74922.20476269722

I get the same results when I explicitly compile the submodules of MyModel, as follows:

def compile_submodules(model):
    # Recursively replace each leaf submodule with its compiled version
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            compile_submodules(module)
        else:
            setattr(model, name, torch.compile(module))
    return model

compiled_temporary_model = compile_submodules(model)

I want to understand better why this is happening. For now I am working around the issue by explicitly compiling the forward function, as below:

compiled_model = torch.compile(model)
compiled_model.forward = torch.compile(compiled_model.forward)

Wrapping your model in torch.compile returns an OptimizedModule that wraps your original module. The OptimizedModule's forward function is the compiled version, but any other attribute access is forwarded to the inner, original module. So when you run compiled_temporary_model.compute_predictions, self inside that method is still the original MyModel instance rather than the OptimizedModule, which means the self(x) call dispatches to the original, uncompiled forward and the second benchmark line shows no speedup.
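Here is a minimal, self-contained sketch of that forwarding behaviour (the Tiny module is made up for illustration, and _orig_mod is a private attribute of OptimizedModule, so treat it as an implementation detail):

import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)

    def compute_predictions(self, x):
        # `self` here determines which forward is invoked
        return self(x)

model = Tiny()
compiled = torch.compile(model)

print(type(compiled).__name__)                         # OptimizedModule
print(compiled._orig_mod is model)                     # True: the wrapper holds the original module
print(compiled.compute_predictions.__self__ is model)  # True: bound to the original module, not the wrapper

Because the bound method keeps a reference to the original module, any call path that does not go through the wrapper's own forward bypasses the compiled graph; that is also why only the explicitly compiled compute_predictions (the third benchmark line) shows a speedup.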
