Hi,
I am trying to script a PyTorch model with TorchScript and am running into the following issue:
My original code indexes into module lists like this:
    for i in range(self.n_layers):
        self.in_layers[i](a),
        self.cond_layers[i](b) ...
Both in_layers and cond_layers are of type torch.nn.ModuleList.
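For comparison, here is a minimal sketch of a loop that, as far as I know, compiles with torch.jit.script in recent PyTorch releases. Instead of indexing the two ModuleLists with the loop variable, it iterates over both of them together with zip, and TorchScript unrolls loops over ModuleLists. The class name TwoStreamBlock, the nn.Linear layers, the channel count, and the way the outputs are combined are hypothetical placeholders, not the original model.

```python
import torch
import torch.nn as nn


class TwoStreamBlock(nn.Module):
    # Hypothetical module: n_layers, in_layers and cond_layers mirror the names
    # in the question; nn.Linear layers and the summing logic are assumptions.
    def __init__(self, n_layers: int = 3, channels: int = 4):
        super().__init__()
        self.n_layers = n_layers
        self.in_layers = nn.ModuleList(
            [nn.Linear(channels, channels) for _ in range(n_layers)]
        )
        self.cond_layers = nn.ModuleList(
            [nn.Linear(channels, channels) for _ in range(n_layers)]
        )

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        out = torch.zeros_like(a)
        # Walk both ModuleLists in lockstep instead of writing in_layers[i];
        # the i-th iteration pairs in_layers[i] with cond_layers[i].
        for in_layer, cond_layer in zip(self.in_layers, self.cond_layers):
            out = out + in_layer(a) + cond_layer(b)
        return out


model = TwoStreamBlock()
scripted = torch.jit.script(model)
```

Comparing model(a, b) with scripted(a, b) on random inputs is a quick way to confirm that the scripted version behaves like eager mode. If the loop body also needs the position, wrapping the zip in enumerate(...) should work the same way.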
When I try to compile, I get an error saying that these types are not subscriptable. To work around this, I wrote a helper function that takes a torch.nn.ModuleList and iterates over its items until it reaches the one at the requested position. I am trying to use MyPy-style type annotations to get it to compile:
    def moduleListChoice(self, ml, choice):
        # type: (torch.nn.ModuleList, int)
        moduleNumber = 0
        for module in ml:
            if moduleNumber == choice:
                return module
            moduleNumber = moduleNumber + 1
        print("Module not found")
The problem is that I can't figure out what type to use for the first argument; annotating it as torch.nn.ModuleList doesn't work. Any thoughts on how to get this to work?
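As far as I know, TorchScript does not treat modules as ordinary values, so a standalone helper that takes or returns a ModuleList is unlikely to compile no matter how it is annotated. A sketch of one alternative, assuming every layer maps a Tensor to a Tensor: keep the selection inside a method of the module that owns the list, loop with enumerate, and return the selected layer's output rather than the layer itself. LayerPicker, select_and_apply, and the identity fallback for an out-of-range choice are hypothetical names and design choices for illustration.

```python
import torch
import torch.nn as nn


class LayerPicker(nn.Module):
    # Hypothetical module; nn.Linear layers stand in for the real ones.
    def __init__(self, n_layers: int = 3, channels: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(channels, channels) for _ in range(n_layers)]
        )

    def select_and_apply(self, x: torch.Tensor, choice: int) -> torch.Tensor:
        # Apply the layer at position `choice` without indexing the ModuleList:
        # TorchScript unrolls the loop, and each iteration compares its
        # position against `choice`.
        out = x
        found = False
        for idx, layer in enumerate(self.layers):
            if idx == choice:
                out = layer(x)
                found = True
        if not found:
            # Mirrors the original helper's message; the input is returned as-is.
            print("Module not found")
        return out

    def forward(self, x: torch.Tensor, choice: int) -> torch.Tensor:
        # select_and_apply gets compiled because forward calls it.
        return self.select_and_apply(x, choice)


model = LayerPicker()
scripted = torch.jit.script(model)
```

For example, scripted(torch.randn(2, 4), 1) should apply the second layer. Returning a tensor instead of a module sidesteps the question of what type the ModuleList argument should have, because the list never leaves the module that owns it.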
Thanks.