import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.module_list = nn.ModuleList(
[nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 1)]
)
def forward(self, batch):
for i in self.module_list[1:4]:
pass
I get a typing error: Expected type ‘Iterable’ (matched generic type ‘Iterable[_T1]’), got ‘Module’ instead.
How do I fix this?
(I would like to avoid creating multiple nn.ModuleLists)