ModuleList typing error: not an iterable

If I use a ModuleList:

import torch.nn as nn

class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.module_list = nn.ModuleList(
            [nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 1)]
        )

    def forward(self, batch):
        for i in self.module_list[1:4]:
            pass

I get a typing error: Expected type ‘Iterable’ (matched generic type ‘Iterable[_T1]’), got ‘Module’ instead.
How do I fix this?
(I would like to avoid creating multiple nn.ModuleLists)

Thanks.

Best

Pascal

Which PyTorch version are you using? Your code works in my setup (1.11.0.dev20211101+cu113):

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.module_list = nn.ModuleList(
            [nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 1)]
        )

    def forward(self, batch):
        for i in self.module_list[1:4]:
            print(i)
        return batch

model = Model()
out = model(torch.randn(1, 1))
print(out.shape)
> torch.Size([1, 1])

Thanks for your quick reply.
I am using ‘1.10.0+cpu’.
The code also runs for me, but I get a typing error, i.e. mypy would complain.
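A small check (not from the original posts) of why the code runs even though the checker complains. At runtime an integer index returns the stored submodule, while a slice returns a brand-new ModuleList that holds the same module objects. Only the static `-> Module` annotation disagrees with this behavior.

```python
import torch.nn as nn

module_list = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)] + [nn.Linear(8, 1)])

single = module_list[0]    # integer index -> the stored nn.Linear
middle = module_list[1:4]  # slice index -> a new nn.ModuleList

print(type(single).__name__)  # Linear
print(type(middle).__name__)  # ModuleList
print(len(middle))            # 3
# The slice holds the very same module objects, not copies.
print(middle[0] is module_list[1])  # True
```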

Ah, OK, sorry, I misunderstood the error.
Unfortunately, I don’t know how to fix the mypy typing error/warning.
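Until the annotation is fixed, here is a sketch of two workarounds that keep a single ModuleList. It assumes that `ModuleList.__iter__` is annotated to yield `Module` (as it is in the releases I know of), so `itertools.islice` avoids `__getitem__` altogether. Alternatively, `typing.cast` keeps real slicing and tells the checker what the runtime result is. A `# type: ignore` comment on the loop line would also silence the warning. The helper names `middle_layers_islice` and `middle_layers_cast` are hypothetical, and applying layers 1 to 3 in `forward` is only for illustration.

```python
import itertools
from typing import Iterator, cast

import torch
import torch.nn as nn


def middle_layers_islice(module_list: nn.ModuleList) -> Iterator[nn.Module]:
    # islice works on any iterable; it consumes ModuleList.__iter__, so the
    # checker never sees the __getitem__ annotation.
    return itertools.islice(module_list, 1, 4)


def middle_layers_cast(module_list: nn.ModuleList) -> nn.ModuleList:
    # Keep real slicing, but assert to the checker what it returns at runtime.
    return cast(nn.ModuleList, module_list[1:4])


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.module_list = nn.ModuleList(
            [nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 1)]
        )

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        # Apply layers 1..3 (all 8 -> 8) via the islice variant.
        for layer in middle_layers_islice(self.module_list):
            batch = layer(batch)
        return batch
```

`islice` is lazy and needs no new container, whereas the `cast` variant builds a new ModuleList on every call, just like the original slice.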

I took a look into the code, and I guess this is a bug in:
/nn/modules/container.py

def __getitem__(self, idx: int) -> Module:
    if isinstance(idx, slice):
        return self.__class__(list(self._modules.values())[idx])
    else:
        return self._modules[self._get_abs_string_index(idx)]

The annotated return type is Module, but a slice index returns a new ModuleList (built via self.__class__), so the actual return type is Union[Module, 'ModuleList'].
Should I open a PR for this?
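One way such a fix could look is with `typing.overload`, so that a checker infers `Module` for an integer index and a ModuleList for a slice. The sketch below is a hypothetical subclass, `TypedModuleList`, written for illustration rather than as a patch to `container.py`. It reuses the base class for integer indices and rebuilds the slice with `self.__class__`, mirroring the quoted code.

```python
from typing import Union, overload

import torch.nn as nn


class TypedModuleList(nn.ModuleList):
    """Hypothetical ModuleList subclass with slice-aware type annotations."""

    @overload
    def __getitem__(self, idx: int) -> nn.Module: ...

    @overload
    def __getitem__(self, idx: slice) -> "TypedModuleList": ...

    def __getitem__(self, idx: Union[int, slice]) -> Union[nn.Module, "TypedModuleList"]:
        if isinstance(idx, slice):
            # Iterating self yields the stored modules in order; slice that list
            # and wrap it in a new container of the same class.
            return self.__class__(list(self)[idx])
        # Integer (including negative) indices are handled by nn.ModuleList.
        return super().__getitem__(idx)
```

Because the slice overload returns a ModuleList subclass, which is itself a `Module`, the override stays compatible with the base class's `-> Module` annotation.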

Yes, I’m sure your fix is more than welcome!