Pascal1
(Pascal)
November 30, 2021, 10:31am
#1
If I slice an nn.ModuleList:
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.module_list = nn.ModuleList(
            [nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 1)]
        )

    def forward(self, batch):
        for i in self.module_list[1:4]:
            pass
I get a typing error: Expected type ‘Iterable’ (matched generic type ‘Iterable[_T1]’), got ‘Module’ instead.
How do I fix this?
(I would like to avoid creating multiple nn.ModuleLists)
Thanks.
Best
Pascal
ptrblck
November 30, 2021, 10:37am
#2
Which PyTorch version are you using? Your code works in my setup (1.11.0.dev20211101+cu113):
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.module_list = nn.ModuleList(
            [nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 1)]
        )

    def forward(self, batch):
        for i in self.module_list[1:4]:
            print(i)
        return batch

model = Model()
out = model(torch.randn(1, 1))
print(out.shape)
> torch.Size([1, 1])
Pascal1
(Pascal)
November 30, 2021, 10:39am
#3
Thanks for your quick reply.
I am using ‘1.10.0+cpu’.
The code also runs for me, but I get a typing error, i.e. mypy would complain.
ptrblck
November 30, 2021, 10:47am
#4
Ah, OK, sorry, I misunderstood the error.
Unfortunately, I don’t know how to fix the mypy typing error/warning.
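A possible workaround, sketched under the assumption that the slice is only needed for iteration: `itertools.islice` walks the ModuleList through its `__iter__` (assumed here to be annotated as returning `Iterator[Module]`, as in recent releases), so no slice-typed `__getitem__` call is involved and no second ModuleList is created. This is a hedged sketch, not checked against every PyTorch/mypy combination; the `forward` below also applies the layers, which the original snippet did not.

```python
import itertools

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.module_list = nn.ModuleList(
            [nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 1)]
        )

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        # islice(iterable, 1, 4) yields the same layers as module_list[1:4]
        # (indices 1, 2, 3) but goes through __iter__ instead of __getitem__,
        # so the type checker sees an iterable of modules.
        for layer in itertools.islice(self.module_list, 1, 4):
            batch = layer(batch)
        return batch


model = Model()
out = model(torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 8])
```

If you prefer to keep the slice syntax, `typing.cast(nn.ModuleList, self.module_list[1:4])` is another option, at the cost of an explicit cast at each call site.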
Pascal1
(Pascal)
November 30, 2021, 11:01am
#5
I took a look into the code, and I guess this is a bug in /nn/modules/container.py:
def __getitem__(self, idx: int) -> Module:
    if isinstance(idx, slice):
        return self.__class__(list(self._modules.values())[idx])
    else:
        return self._modules[self._get_abs_string_index(idx)]
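The return type is actually not Module but Union[Module, 'ModuleList'], since a slice returns a new ModuleList (and idx would presumably need to accept a slice as well, e.g. Union[int, slice]).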
Should I open a PR for this?
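One caveat worth considering for such a fix: with a plain Union return type, a strict checker such as mypy would likely still flag `for i in self.module_list[1:4]`, because one member of the union (Module) is not iterable. `typing.overload` lets the checker pick the return type from the argument type instead. The sketch below uses hypothetical pure-Python stand-ins (`Layer` for nn.Module, `LayerList` for nn.ModuleList) to show the pattern; it is not PyTorch's actual code.

```python
from typing import Iterable, Iterator, List, Union, overload


class Layer:
    """Hypothetical stand-in for nn.Module."""

    def __init__(self, name: str) -> None:
        self.name = name


class LayerList:
    """Hypothetical stand-in for nn.ModuleList, reduced to indexing and iteration."""

    def __init__(self, layers: Iterable[Layer]) -> None:
        self._layers: List[Layer] = list(layers)

    # Type checkers see only these two signatures: int -> Layer, slice -> LayerList.
    @overload
    def __getitem__(self, idx: int) -> Layer: ...

    @overload
    def __getitem__(self, idx: slice) -> "LayerList": ...

    def __getitem__(self, idx: Union[int, slice]) -> Union[Layer, "LayerList"]:
        # Same runtime dispatch as the container code above: a slice builds a
        # new container of the same class, an int returns a single element.
        if isinstance(idx, slice):
            return self.__class__(self._layers[idx])
        return self._layers[idx]

    def __iter__(self) -> Iterator[Layer]:
        return iter(self._layers)

    def __len__(self) -> int:
        return len(self._layers)


layers = LayerList(Layer(f"linear{i}") for i in range(5))
# With the overloads, layers[1:4] is inferred as LayerList, so iterating it type-checks.
names = [layer.name for layer in layers[1:4]]
print(names)  # ['linear1', 'linear2', 'linear3']
```

The runtime behavior is identical to a Union-annotated version; only the information given to the type checker changes.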
ptrblck
December 1, 2021, 9:51am
#6
Yes, I’m sure your fix is more than welcome!