I have defined a sample model:
```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
```
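For comparison with the traced graph later, here is a minimal check (assuming a standard PyTorch install) of the qualified names under which the original module registers its parameters. The names come from the submodule attribute `lin`:

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()

# Qualified parameter names as seen from the top-level module
param_names = [name for name, _ in mod.named_parameters()]
print(param_names)  # ['lin.weight', 'lin.bias']
```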
I create an fx graph using the make_fx function from torch.fx.experimental.proxy_tensor:
```python
from torch.fx.experimental.proxy_tensor import make_fx

fx_graph_sample = make_fx(mod)(torch.rand(100, 100))
```
make_fx returns a GraphModule, stored in fx_graph_sample; the traced fx graph itself is fx_graph_sample.graph.
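One way to see where the name mismatch comes from is to list the `get_attr` nodes of the traced graph and resolve each target on the GraphModule that `make_fx` returned. This is a sketch; the generated attribute names are an implementation detail of `make_fx` and may differ between PyTorch versions (the error further down suggests names like `_param_constant0`). The helper `resolve` is hypothetical, written only for this example:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
fx_graph_sample = make_fx(mod)(torch.rand(100, 100))

def resolve(root, target):
    # Hypothetical helper: follow a dotted attribute path such as "a.b" from root
    obj = root
    for part in target.split("."):
        obj = getattr(obj, part)
    return obj

# Targets of get_attr nodes, i.e. the attribute names the graph expects on its root
targets = [n.target for n in fx_graph_sample.graph.nodes if n.op == "get_attr"]
# Each target resolves on the traced GraphModule, which owns these attributes
resolved = {t: resolve(fx_graph_sample, t) for t in targets}

print(type(fx_graph_sample).__name__)
print(targets)
print(dict(mod.named_parameters()).keys())
```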
Using the original model and the fx graph, the idea is to create a GraphModule:
```python
torch.fx.GraphModule(mod, fx_graph_sample.graph)
```
The error I get is:
```
AttributeError: 'MyModule' object has no attribute '_param_constant0'
```
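The missing attribute is registered on the GraphModule that `make_fx` returned, not on `mod`. So one possible workaround, assuming you only need a GraphModule that runs the traced graph (rather than one rooted at the original `MyModule` instance), is to pass that GraphModule itself as the root. A hedged sketch:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
fx_graph_sample = make_fx(mod)(torch.rand(100, 100))

# Root the new GraphModule at the traced module, which owns the generated
# attributes referenced by the graph's get_attr nodes
rebuilt = torch.fx.GraphModule(fx_graph_sample, fx_graph_sample.graph)

x = torch.rand(100, 100)
out = rebuilt(x)
expected = mod(x)
print(torch.allclose(out, expected))
```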
When the graph is generated with make_fx, iterating through named_parameters() of the resulting module gives different parameter names (such as _param_constant0) from the ones in the original MyModule subclass (lin.weight and lin.bias). How can I generate the GraphModule from the original model?
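If the GraphModule should be rooted at the original model, one approach is to rewrite each `get_attr` target to the qualified name of the same parameter in `mod`, matched by object identity, before building the GraphModule. This sketch assumes that `make_fx` stores the original `Parameter` objects under its generated names (not copies), which is consistent with the error above but is not a documented guarantee:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
fx_graph_sample = make_fx(mod)(torch.rand(100, 100))

# Map each Parameter object (by identity) to its qualified name in mod
name_by_id = {id(p): name for name, p in mod.named_parameters()}

graph = fx_graph_sample.graph
for node in graph.nodes:
    if node.op == "get_attr":
        # Look up the tensor the generated name refers to on the traced module
        attr = fx_graph_sample
        for part in node.target.split("."):
            attr = getattr(attr, part)
        # Point the node at the original name, e.g. "lin.weight"
        if id(attr) in name_by_id:
            node.target = name_by_id[id(attr)]

# Every get_attr target now exists on mod, so mod can serve as the root
rebuilt = torch.fx.GraphModule(mod, graph)

x = torch.rand(100, 100)
out = rebuilt(x)
expected = mod(x)
new_targets = {n.target for n in rebuilt.graph.nodes if n.op == "get_attr"}
print(new_targets)
```

Note that this mutates `fx_graph_sample.graph` in place, so `fx_graph_sample` should not be used afterwards; work on a copy of the graph if it must stay usable.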
PS: The same approach, when using symbolic_trace to generate the fx_graph, works perfectly.
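For comparison, `symbolic_trace` treats `torch.nn.Linear` as a leaf module and records `self.lin` as a `call_module` node whose target already names a submodule of `mod`, which is presumably why rooting the GraphModule at `mod` works in that case. A small check:

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
traced = torch.fx.symbolic_trace(mod)

# Nodes that reference attributes or submodules of the root
module_refs = [(n.op, n.target) for n in traced.graph.nodes
               if n.op in ("call_module", "get_attr")]
print(module_refs)  # [('call_module', 'lin')]

# "lin" exists on mod, so mod works as the root here
rebuilt = torch.fx.GraphModule(mod, traced.graph)
x = torch.rand(100, 100)
same = torch.allclose(rebuilt(x), mod(x))
```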