How to iterate a ParameterList in TorchScript?

Hi,

Is it possible to iterate over a ParameterList in TorchScript? I was surprised to find that even the simple example from the PyTorch documentation does not work. This snippet

import torch
import torch.nn as nn

# from https://pytorch.org/docs/stable/generated/torch.nn.ParameterList.html
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterList(
            [nn.Parameter(torch.randn(10, 10)) for i in range(10)]
        )

    def forward(self, x):
        for i, p in enumerate(self.params):
            x = self.params[i // 2].mm(x) + p.mm(x)
        return x


if __name__ == "__main__":
    f = MyModule()
    f = torch.jit.script(f)  # BOOM

works fine if not JITed but fails otherwise:

[...]
Only constant Sequential, ModueList, or ModuleDict can be used as an iterable:
  File "[...]", line 14
    def forward(self, x):
        for i, p in enumerate(self.params):
                    ~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            x = self.params[i // 2].mm(x) + p.mm(x)
        return x

Are there any workarounds? In particular, I have a ModuleList and a ParameterList of the same length, and in forward I want to pass the n-th parameter to the n-th module:

    def __init__(self, n: int) -> None:
        super().__init__()

        [...]
        self.my_modules = nn.ModuleList([...])  # n modules
        self.meta_parameters = nn.ParameterList([...]) # n nn.Parameters

    def forward(self, x):
        # for f, p in zip(self.my_modules, self.meta_parameters):
        for i, f in enumerate(self.my_modules):
            x = f(x, self.meta_parameters[i])

        return x
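If all meta-parameters happen to have the same shape, one possible workaround (a sketch, not something confirmed in this thread) is to store them as rows of a single stacked nn.Parameter instead of a ParameterList. TorchScript accepts enumerate over a ModuleList, and indexing an ordinary tensor attribute with the loop index is plain tensor indexing, so the loop in forward can stay as written. AddBias and StackedModel are hypothetical stand-ins for the real modules.

```python
import torch
import torch.nn as nn


class AddBias(nn.Module):
    """Hypothetical stand-in for one of the n modules: takes x and its meta-parameter."""

    def forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        return x + p


class StackedModel(nn.Module):
    def __init__(self, n: int, dim: int) -> None:
        super().__init__()
        self.my_modules = nn.ModuleList([AddBias() for _ in range(n)])
        # Assumption: every meta-parameter has the same shape (dim,), so all n of
        # them can live as rows of one (n, dim) Parameter instead of a ParameterList.
        self.meta_parameters = nn.Parameter(torch.zeros(n, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Enumerating a ModuleList is allowed in TorchScript; indexing a tensor with i is too.
        for i, f in enumerate(self.my_modules):
            x = f(x, self.meta_parameters[i])
        return x


model = StackedModel(n=4, dim=3)
with torch.no_grad():
    # Give each row a distinct value so the result is easy to check by hand.
    model.meta_parameters.copy_(torch.arange(12.0).view(4, 3))

scripted = torch.jit.script(model)
x = torch.zeros(3)
out = scripted(x)  # x plus the sum of the four rows
```

The trade-off is that the n parameters become one tensor: they must share a shape, and the state_dict entries meta_parameters.0, meta_parameters.1, ... collapse into a single meta_parameters key.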

Thanks for your help!
Nis

This should be fixed by "[JIT] support parameterlist iteration" (pytorch/pytorch pull request #76140, by davidberard98).
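On PyTorch builds that do not include that change, or when the parameters have different shapes, another possible workaround (a sketch under those assumptions, not the fix from the PR) is to wrap each parameter in a tiny module and keep the wrappers in a second ModuleList. TorchScript can unroll a zip over ModuleLists at compile time, so the n-th module receives the n-th parameter. ParamHolder, Scale and Model are hypothetical names.

```python
import torch
import torch.nn as nn


class ParamHolder(nn.Module):
    """Hypothetical wrapper that holds one nn.Parameter so it can sit in a ModuleList."""

    def __init__(self, param: nn.Parameter) -> None:
        super().__init__()
        self.param = param

    def forward(self) -> torch.Tensor:
        return self.param


class Scale(nn.Module):
    """Hypothetical stand-in for one of the n modules: takes x and its meta-parameter."""

    def forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        return x * p


class Model(nn.Module):
    def __init__(self, n: int) -> None:
        super().__init__()
        self.my_modules = nn.ModuleList([Scale() for _ in range(n)])
        # One wrapper module per parameter; the values 1, 2, ..., n are illustrative only.
        self.meta_parameters = nn.ModuleList(
            [ParamHolder(nn.Parameter(torch.full((3,), float(i + 1)))) for i in range(n)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # zip over two ModuleLists is unrolled by the TorchScript compiler.
        for f, holder in zip(self.my_modules, self.meta_parameters):
            x = f(x, holder())
        return x


m = Model(3)
scripted = torch.jit.script(m)
x = torch.ones(3)
out = scripted(x)  # each element: 1 * 1 * 2 * 3
```

The parameters are still registered (m.parameters() yields all n of them), but their state_dict keys change to meta_parameters.0.param, meta_parameters.1.param, and so on.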
