Given a torch.nn.Linear module, I noticed that the aot_autograd backend transforms the torch IR captured by dynamo into torch.ops.aten.addmm.default, while the pre_dispatch_eager backend transforms it into torch.ops.aten.linear.default.
I suspect this might have something to do with pre_dispatch, but I do not fully understand the concept of pre_dispatch or how it works. Could someone provide some guidance on where to find documentation or code examples that explain pre_dispatch in more detail?
Thank you very much for any help!
Possibly related owners: @ezyang @Chillee @bdhirsh (please ignore this if it is not relevant to you).
from typing import List
from torch._dynamo.backends.common import aot_autograd
import torch
torch._dynamo.config.suppress_errors = True
class Bar(torch.nn.Module):
    """Toy module: elementwise product of two parallel Linear projections.

    Both branches map a 5-dim input to a 10-dim output; the forward pass
    combines them multiplicatively. Used here only to observe how different
    torch.compile backends capture nn.Linear.
    """

    def __init__(self):
        super().__init__()
        # Two independent projections over the same input.
        self.fc1 = torch.nn.Linear(5, 10)
        self.fc2 = torch.nn.Linear(5, 10)

    def forward(self, inputs):
        # Run both branches on the same input, then combine elementwise.
        return torch.mul(self.fc1(inputs), self.fc2(inputs))
# --- Backend 1: pre_dispatch_eager ---
# With this backend the captured graph keeps torch.ops.aten.linear.default
# instead of decomposing it (confirm against the printed graph).
user_model = Bar()
compiled_func = torch.compile(user_model, backend="pre_dispatch_eager")
print(compiled_func)

# One positive-ish and one negated random input to exercise the compiled module.
inp1 = torch.randn(1, 5)
inp2 = torch.randn(1, 5).mul(-1)
out = compiled_func(inp1)
out2 = compiled_func(inp2)
print(out)
print(out2)
# Use aot_autograd
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    """Forward-compiler hook for aot_autograd: dump the FX graph, run it as-is.

    Prints a human-readable rendering of the traced graph and returns the
    graph module's forward callable unchanged (no actual compilation).
    """
    print("custom backend called with FX graph:")
    gm.print_readable()
    runner = gm.forward
    return runner
# --- Backend 2: aot_autograd with a pass-through forward compiler ---
# By the time custom_backend receives the graph, aot_autograd has lowered it;
# here nn.Linear shows up as torch.ops.aten.addmm.default in the printed graph.
aot_backend = aot_autograd(fw_compiler=custom_backend)
compiled_func = torch.compile(Bar(), backend=aot_backend)
print(compiled_func)

# Same input recipe as the pre_dispatch_eager run above.
inp1 = torch.randn(1, 5)
inp2 = torch.randn(1, 5).mul(-1)
out = compiled_func(inp1)
out2 = compiled_func(inp2)
print(out)
print(out2)