Dear Experts:
I am having trouble running Export IR through `fx.Interpreter`. I expected the two to be compatible, because I read that Export IR is built on FX IR, and code such as `torch.export.unflatten.InterpreterModule` uses `fx.Interpreter` internally. However, that does not seem to be the case (`torch.export.unflatten.InterpreterModule` does not run correctly on my setup either). Below is a minimal test. The top half (FX IR + FX Interpreter) runs correctly, but the bottom half (Export IR + FX Interpreter) fails:
```python
import torch
from torch import nn
from torch.export import export


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer0 = nn.Conv2d(4, 10, 1, 1)
        self.layer1 = nn.Conv2d(10, 5, 1, 1)

    def forward(self, x):
        return self.layer1(self.layer0(x))


x = torch.rand([7, 4, 32, 32])
net = MyModule()

# This runs correctly
gm = torch.fx.symbolic_trace(net)
ip = torch.fx.Interpreter(net, graph=gm.graph)
ip.run(x)

# This doesn't
exported_program = export(net, args=(x,))
ip = torch.fx.Interpreter(net, graph=exported_program.graph)
ip.run(x)
```
The error message is:
```
Traceback (most recent call last):
  File "/.../test.py", line 46, in <module>
    ip.run(x)
  File "/.../lib/python3.11/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.11/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.11/site-packages/torch/fx/interpreter.py", line 236, in placeholder
    raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si
RuntimeError: Expected positional argument for parameter p_layer0_bias, but one was not passed in!
```
If I am doing something wrong, I would be grateful if anyone could help me understand what it is.