When I `torch.export` the sample module from the official tutorial:
```python
import torch
from torch.export import export

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod(torch.randn(8, 100), torch.randn(8, 100)))
```
the resulting graph doesn't capture the default value of the `alpha` parameter of the sum `x + y`:

```
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, %arg3_1), kwargs = {})
```
By contrast, if I explicitly pass `alpha`:

```python
return torch.nn.functional.relu(self.lin(torch.add(x, y, alpha=0.5)), inplace=True)
```

then the graph does contain the third argument:

```
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, %arg3_1), kwargs = {alpha: 0.5})
```
Not including the default value of an argument (when the caller leaves it unspecified) is a problem for backends that consume the FX graph, because the backend then has to reach outside the graph to discover the default.
This is what the shark-turbine importer does, for example, in https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/dynamo/importer.py#L609: instead of iterating over, say, the arguments of a `call_function` node, it retrieves a `FunctionSchema` from the node (a protected field) and iterates over that.
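For illustration, the workaround looks roughly like the sketch below: it recovers the omitted defaults by reading the operator's `FunctionSchema` off the `OpOverload` (the same protected `_schema` field the importer touches). `defaults_for` is a hypothetical helper name, not part of any torch API:

```python
import torch

# Sketch of the workaround: recover the defaults of arguments the caller
# left unspecified by consulting the operator's FunctionSchema via the
# protected `_schema` field on the OpOverload.
def defaults_for(op: torch._ops.OpOverload) -> dict:
    """Map argument name -> default value, for arguments that have one."""
    return {
        arg.name: arg.default_value
        for arg in op._schema.arguments
        if arg.default_value is not None
    }

# The alpha that was missing from the graph's kwargs shows up here.
print(defaults_for(torch.ops.aten.add.Tensor))
```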
This seems like a bug in `torch.export`, since it breaks encapsulation. Thoughts?