Symbolic trace with *args

Hi all.

I’m trying to trace a model whose forward takes positional arguments (*args).
As a simplified example, Container forwards its positional arguments to the Module it wraps.
When I pass Container to fx.symbolic_trace, I get the following error:

torch.fx.proxy.TraceError: Proxy object cannot be iterated. This can be attempted when used in a for loop or as a *args or **kwargs function argument.

This happens because the tracer replaces the whole *args parameter with a single Proxy, so unpacking it in self._module(*args) tries to iterate over that Proxy.

What can be done to solve this problem?

from torch import Tensor
import torch.nn as nn
import torch.fx as fx

class Module(nn.Module):
    def forward(self, x):
        return x

class Container(nn.Module):
    def __init__(self):
        # nn.Module must be initialized before submodules can be assigned.
        super().__init__()
        self._module = Module()

    def forward(self, *args):
        return self._module(*args)

model = Container()

trace = fx.symbolic_trace(model)  # Raises torch.fx.proxy.TraceError.
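
One possible workaround, as a sketch rather than a definitive fix: if the number of positional arguments is known at trace time (one here, which is an assumption), Container can be wrapped in a module with an explicit signature. FixedArity is a hypothetical name for illustration, not a torch.fx API. Since Container is a user-defined module, the default tracer steps into its forward with an ordinary tuple of Proxy objects, so the *args unpacking works.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class Module(nn.Module):
    def forward(self, x):
        return x


class Container(nn.Module):
    def __init__(self):
        super().__init__()
        self._module = Module()

    def forward(self, *args):
        return self._module(*args)


class FixedArity(nn.Module):
    """Hypothetical wrapper (not part of torch.fx) with a fixed signature."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        # Container is not a leaf module, so the tracer calls its forward
        # with args == (x,): a real tuple holding one Proxy, not a Proxy
        # standing in for the whole *args.
        return self.inner(x)


traced = fx.symbolic_trace(FixedArity(Container()))
print(traced.graph)
```

The trade-off is that the traced graph is specialized to that one arity; a model called with a different number of arguments would need its own wrapper and its own trace.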