Hi all.
I’m trying to trace a model that takes positional arguments (`*args`) in `forward`.
In the simplified version of the problem, a `Container` forwards its positional arguments to a `Module`.
When I pass `Container` to `fx.symbolic_trace`, I get the following error:
torch.fx.proxy.TraceError: Proxy object cannot be iterated. This can be attempted when used in a for loop or as a *args or **kwargs function argument.
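This happens because the `*`-prefixed argument is replaced by a single `Proxy`, and unpacking it with `*args` in the call to `self._module` then tries to iterate over that `Proxy`.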
What can be done to solve this problem?
```python
from torch import Tensor
import torch.nn as nn
import torch.fx as fx


class Module(nn.Module):
    def forward(self, x):
        return x


class Container(nn.Module):
    def __init__(self):
        super().__init__()
        self._module = Module()

    def forward(self, *args):
        # Forward all positional arguments to the wrapped module.
        return self._module(*args)


model = Container()
model(Tensor())  # The eager call works.
trace = fx.symbolic_trace(model)  # Raises TraceError.
```
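One possible workaround, sketched under the assumption that the number of inputs is fixed and known at trace time, is to wrap `Container` in a module whose `forward` has an explicit signature (`FixedArityWrapper` is a hypothetical name, not part of `torch.fx`). Since `Container` is not a built-in `torch.nn` module, `symbolic_trace` traces through its `forward` instead of treating it as a leaf. Inside that call `args` is an ordinary tuple holding one `Proxy`, so `*args` unpacks it without iterating a `Proxy`:

```python
import torch
import torch.nn as nn
import torch.fx as fx


class Module(nn.Module):
    def forward(self, x):
        return x


class Container(nn.Module):
    def __init__(self):
        super().__init__()
        self._module = Module()

    def forward(self, *args):
        return self._module(*args)


class FixedArityWrapper(nn.Module):
    # Hypothetical helper: gives the traced root a fixed, named signature,
    # so symbolic_trace creates one placeholder per input instead of a
    # single Proxy standing in for "*args".
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        # Container.forward now receives args == (proxy_for_x,), a real tuple,
        # so the *args unpacking inside it does not iterate a Proxy.
        return self.inner(x)


traced = fx.symbolic_trace(FixedArityWrapper(Container()))
print(traced.code)  # Inspect the generated forward.

out = traced(torch.ones(3))
```

The trade-off is that the traced graph has a fixed number of inputs; a model called with two arguments would need a wrapper whose `forward` takes two.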
Thanks.