To complement @fmassa’s post: it fails because we only capture modules that are assigned directly to the Module object; anything more general gets too tricky and bug-prone. There are a number of tricks you can use to get around it, with ListModule shown above being one of them. If I were to suggest something, I’d keep all the modules in a single container — the Module itself, registered under indexed names — and access them through a small index-to-attribute proxy, like this:
class AttrProxy(object):
    """Translate index lookups into attribute lookups.

    ``AttrProxy(module, 'i2h_')[3]`` returns ``module.i2h_3``. This lets
    sub-modules that were registered on ``module`` under indexed names
    (e.g. via ``add_module``) be accessed and iterated like a list.
    """

    def __init__(self, module, prefix):
        # Object whose attributes are looked up.
        self.module = module
        # Attribute-name prefix; the index is appended to it as a string.
        self.prefix = prefix

    def __getitem__(self, i):
        """Return the attribute named ``prefix + str(i)`` on the wrapped object.

        Raises:
            IndexError: if no such attribute exists. IndexError (rather than
                the underlying AttributeError) is what Python's sequence
                iteration protocol treats as "end of sequence", so
                ``for m in proxy`` or ``zip(xs, proxy)`` stop cleanly after
                the last registered index instead of crashing.
        """
        name = self.prefix + str(i)
        try:
            return getattr(self.module, name)
        except AttributeError:
            raise IndexError('{!r} has no attribute {!r}'.format(
                type(self.module).__name__, name))
class testNet(nn.Module):
    """Unrolled tanh RNN with its own pair of Linear layers for every step.

    Step ``t`` computes ``hidden = tanh(i2h_t(x_t) + h2h_t(hidden))``.
    The per-step layers are registered with ``add_module`` under the names
    ``i2h_<t>`` / ``h2h_<t>`` so their parameters are tracked by the module,
    and are exposed as indexable views ``self.i2h[t]`` / ``self.h2h[t]``.

    Args:
        input_dim: ``in_features`` of each input-to-hidden Linear layer.
        hidden_dim: ``out_features`` of every layer (and ``in_features`` of
            each hidden-to-hidden layer).
        steps: number of unrolled steps, i.e. how many layer pairs to create.
    """

    def __init__(self, input_dim, hidden_dim, steps=1):
        super(testNet, self).__init__()
        self.steps = steps
        for i in range(steps):
            # add_module registers the layers on this Module, so they are
            # included in .parameters(), state_dict(), etc.
            self.add_module('i2h_' + str(i), nn.Linear(input_dim, hidden_dim))
            self.add_module('h2h_' + str(i), nn.Linear(hidden_dim, hidden_dim))
        # Index-to-attribute views: self.i2h[t] is self.i2h_<t>, etc.
        self.i2h = AttrProxy(self, 'i2h_')
        self.h2h = AttrProxy(self, 'h2h_')

    def forward(self, input, hidden):
        """Run the recurrence over ``input`` and return the final hidden state.

        Args:
            input: sequence iterated along its first dimension; each element
                is one time step (assumes shape (seq_len, batch, input_dim)
                — TODO confirm against callers).
            hidden: initial hidden state fed to the first h2h layer.

        Returns:
            The hidden state after the last processed step. At most
            ``self.steps`` time steps are processed, since only that many
            layer pairs exist; extra time steps are ignored. If ``input`` is
            empty, ``hidden`` is returned unchanged.
        """
        # Bound the loop explicitly by self.steps so we never ask for a layer
        # that was not created in __init__.
        for step, inp in zip(range(self.steps), input):
            # Feed only the current time step (inp), not the whole sequence.
            hidden = F.tanh(self.i2h[step](inp) + self.h2h[step](hidden))
        return hidden