Hi,
Is it possible to iterate over an nn.ParameterList in TorchScript? I was surprised to see that even the simple example from the PyTorch documentation fails to compile, i.e., this snippet
```python
# from https://pytorch.org/docs/stable/generated/torch.nn.ParameterList.html
import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterList(
            [nn.Parameter(torch.randn(10, 10)) for i in range(10)]
        )

    def forward(self, x):
        for i, p in enumerate(self.params):
            x = self.params[i // 2].mm(x) + p.mm(x)
        return x


if __name__ == "__main__":
    f = MyModule()
    f = torch.jit.script(f)  # BOOM
```
works fine in eager mode but fails when compiled with `torch.jit.script`:
```
[...]
Only constant Sequential, ModueList, or ModuleDict can be used as an iterable:
  File "[...]", line 14
    def forward(self, x):
        for i, p in enumerate(self.params):
                    ~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            x = self.params[i // 2].mm(x) + p.mm(x)
        return x
```
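For the documentation example itself, one workaround that should script, assuming every entry of the list has the same shape, is to replace the ParameterList with a single stacked nn.Parameter and loop over an integer range, indexing the tensor instead of the list. The class name `MyModuleStacked` is hypothetical. Note that the state_dict then holds one `params` tensor of shape (10, 10, 10) rather than ten separate entries.

```python
import torch
import torch.nn as nn


class MyModuleStacked(nn.Module):
    """Hypothetical scriptable variant of the docs example: one stacked
    parameter of shape (n, dim, dim) replaces a ParameterList of n (dim, dim)
    tensors. Assumes all list entries share the same shape."""

    def __init__(self, n: int = 10, dim: int = 10):
        super().__init__()
        # self.params[i] plays the role of the i-th ParameterList entry
        self.params = nn.Parameter(torch.randn(n, dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Loop over an int range instead of the ParameterList;
        # integer indexing into a tensor is fine in TorchScript
        for i in range(self.params.size(0)):
            x = self.params[i // 2].mm(x) + self.params[i].mm(x)
        return x


if __name__ == "__main__":
    f = torch.jit.script(MyModuleStacked())
    x = torch.randn(10, 3)
    print(f(x).shape)
```

The computation is the same as in the original example; only the storage of the parameters changes from ten tensors to one.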
Are there any workarounds? In particular, I have a ModuleList and a ParameterList of the same length, and in the forward pass I want to pass the n-th parameter to the n-th module:
```python
def __init__(self, n: int) -> None:
    super().__init__()
    [...]
    self.my_modules = nn.ModuleList([...])  # n modules
    self.meta_parameters = nn.ParameterList([...])  # n nn.Parameters

def forward(self, x):
    # for f, p in zip(self.my_modules, self.meta_parameters):
    for i, f in enumerate(self.my_modules):
        x = f(x, self.meta_parameters[i])
    return x
```
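For the ModuleList plus ParameterList case, one possible workaround is a sketch that never iterates or indexes a ParameterList: bundle each module with its parameter in a small wrapper module and iterate a single ModuleList of wrappers, which TorchScript handles by unrolling the loop. `ScaledLinear`, `Block`, and `Model` are hypothetical names; `ScaledLinear` only stands in for modules with an `f(x, p)` signature like the ones above.

```python
import torch
import torch.nn as nn


class ScaledLinear(nn.Module):
    """Hypothetical stand-in for one of the n modules: it takes the input
    and a meta-parameter, like f(x, self.meta_parameters[i])."""

    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        return self.linear(x) * p


class Block(nn.Module):
    """Hypothetical wrapper pairing the n-th module with the n-th parameter,
    so forward() never has to touch a ParameterList."""

    def __init__(self, f: nn.Module, p: nn.Parameter):
        super().__init__()
        self.f = f
        self.p = p  # registered as a parameter of this Block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.f(x, self.p)


class Model(nn.Module):
    def __init__(self, n: int, dim: int):
        super().__init__()
        self.blocks = nn.ModuleList(
            [Block(ScaledLinear(dim), nn.Parameter(torch.ones(1))) for _ in range(n)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Iterating a ModuleList is supported by TorchScript (the loop is unrolled)
        for block in self.blocks:
            x = block(x)
        return x


scripted = torch.jit.script(Model(n=4, dim=8))
```

If all meta-parameters share one shape, the stacked-tensor idea should also apply here: a single `nn.Parameter` of shape `(n, ...)` indexed as `self.meta_parameters[i]` inside the existing `for i, f in enumerate(self.my_modules)` loop. Either way the state_dict keys change (e.g. `blocks.0.p` here), so existing checkpoints would need their keys remapped.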
Thanks for your help!
Nis