I have the following model classes:
import torch
import torch.nn as nn

class BaseClass(nn.Module):
    ...

class MyModel(BaseClass):
    def __init__(self):
        super().__init__()
        self.transformer = CustomTransformer(10, 10)
        self.linears = ...

    def forward(self, x):
        return self.transformer(x)

    def compute_predictions(self, x):
        # mask x
        return self(x)

class CustomTransformer(BaseClass):
    def __init__(self, d_model, nhead):
        super().__init__()
        self.layers = ...

    def forward(self, x):
        return self.layers(x)
Usually I call the compute_predictions method from outside; it performs some masking/indexing on the input x and then calls the forward method, which accounts for most of the runtime.
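For context, the call site looks roughly like this (the batch shape below is a hypothetical stand-in for my real data):

model = MyModel()
batch = torch.randn(32, 10)                     # hypothetical shape; d_model = 10
predictions = model.compute_predictions(batch)  # masks x, then dispatches to forward via self(x)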
I am trying to benchmark the improvement that torch.compile provides (torch version '2.5.1+cu124'), using the following code:
import torch.utils.benchmark as benchmark

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

model = MyModel()
compiled_temporary_model = torch.compile(model)

with torch.autocast(device_type="cuda", dtype=torch.float16):  # autocast to float16
    print(benchmark_torch_function_in_microseconds(model.compute_predictions, batch))
    print(benchmark_torch_function_in_microseconds(compiled_temporary_model.compute_predictions, batch))
    print(benchmark_torch_function_in_microseconds(torch.compile(compiled_temporary_model.compute_predictions), batch))
The output is as follows:
89699.9528631568
89938.04324418306
75973.31081827481
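I also checked what torch.compile actually returns for a module, since that might be relevant to the second line. My understanding (an assumption on my part, not verified against the dynamo source) is that torch.compile(model) wraps it in an OptimizedModule that intercepts only __call__/forward, while attribute access to other methods falls through to the original module:

print(type(compiled_temporary_model))
# expecting <class 'torch._dynamo.eval_frame.OptimizedModule'>

# compute_predictions resolves to the original module's bound method, so its
# internal self(x) call would go through the uncompiled forward:
print(compiled_temporary_model.compute_predictions.__func__ is MyModel.compute_predictions)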
Still, I would expect the second line to show an improvement, because the forward method should be implicitly compiled. In addition, if I benchmark the inner transformer directly through its forward method:
with torch.autocast(device_type="cuda", dtype=torch.float16):
    print(benchmark_torch_function_in_microseconds(model.transformer, input_))
    print(benchmark_torch_function_in_microseconds(compiled_temporary_model.transformer, input_))
    print(benchmark_torch_function_in_microseconds(compiled_temporary_model.transformer.forward, input_))
    print(benchmark_torch_function_in_microseconds(torch.compile(compiled_temporary_model.transformer), input_))
Output:
86728.8438603282
87124.97558444738
87061.39624118805
74922.20476269722
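In case one-off compilation overhead was skewing these numbers, I also warmed each compiled callable up once before timing (a minimal sketch; input_ is the same tensor as above, and blocked_autorange should already amortize most of this):

compiled_transformer = torch.compile(model.transformer)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    compiled_transformer(input_)  # first call triggers compilation, excluded from the measurement
    print(benchmark_torch_function_in_microseconds(compiled_transformer, input_))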
I get the same results when I explicitly compile the submodules of MyModel, as follows:
def compile_submodules(model):
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            compile_submodules(module)
        else:
            setattr(model, name, torch.compile(module))
    return model

compiled_temporary_model = compile_submodules(model)
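As a sanity check that the recursion actually replaced the leaf modules, I print the module tree afterwards (the names depend on what self.layers and self.linears really are, so this is just illustrative):

for name, module in compiled_temporary_model.named_modules():
    print(name, type(module).__name__)  # leaves should show up as OptimizedModule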
I want to better understand why this is happening. For now I am working around the issue by explicitly compiling the forward function, as below:
compiled_model = torch.compile(model)
compiled_model.forward = torch.compile(compiled_model.forward)
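I then benchmark both entry points again with the same helper to check whether compute_predictions now picks up the compiled forward:

with torch.autocast(device_type="cuda", dtype=torch.float16):
    compiled_model.compute_predictions(batch)  # warm-up, triggers compilation
    print(benchmark_torch_function_in_microseconds(compiled_model.forward, batch))
    print(benchmark_torch_function_in_microseconds(compiled_model.compute_predictions, batch))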