List of nn.Module in a nn.Module

To complement @fmassa’s post: a plain Python list of modules fails because we only register modules that are assigned directly as attributes of the Module object. Modules nested inside a list are therefore invisible to .parameters(), .state_dict() and .cuda(). Tracking them anywhere else gets too tricky and bug-prone. There are a number of tricks you can use to get around it, ListModule shown above being one of them. If I were to suggest something, I’d keep all the modules in a single container like this:

import torch
import torch.nn as nn


class AttrProxy(object):
    """Translates index lookups into attribute lookups."""
    def __init__(self, module, prefix):
        self.module = module
        self.prefix = prefix

    def __getitem__(self, i):
        try:
            return getattr(self.module, self.prefix + str(i))
        except AttributeError:
            # Raising IndexError lets `for` loops and zip() stop cleanly
            # once the index runs past the last registered submodule.
            raise IndexError(i)


class testNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, steps=1):
        super(testNet, self).__init__()
        self.steps = steps
        for i in range(steps):
            # add_module registers each layer, so its parameters are tracked
            self.add_module('i2h_' + str(i), nn.Linear(input_dim, hidden_dim))
            self.add_module('h2h_' + str(i), nn.Linear(hidden_dim, hidden_dim))
        self.i2h = AttrProxy(self, 'i2h_')
        self.h2h = AttrProxy(self, 'h2h_')

    def forward(self, input, hidden):
        # here, use self.i2h[t] and self.h2h[t] to index
        # input2hidden and hidden2hidden modules for each step,
        # or loop over them, like in the example below
        # (assuming first dim of input is sequence length;
        # zip stops at the shorter of seq_len and self.steps)
        for inp, i2h, h2h in zip(input, self.i2h, self.h2h):
            hidden = torch.tanh(i2h(inp) + h2h(hidden))
        return hidden
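To check the registration claim above, here is a minimal sketch (assuming a recent PyTorch release) that compares holding the per-step layers in a plain Python list with holding them in nn.ModuleList, the built-in container current PyTorch provides for exactly this case. The class names PlainListNet and ModuleListNet and the helper num_params are hypothetical, made up for illustration. The list version reports zero parameters, so an optimizer built from .parameters() would have nothing to train.

```python
import torch
import torch.nn as nn


class PlainListNet(nn.Module):
    """Hypothetical example: layers kept in a plain Python list are NOT registered."""
    def __init__(self, input_dim, hidden_dim, steps=1):
        super().__init__()
        self.i2h = [nn.Linear(input_dim, hidden_dim) for _ in range(steps)]
        self.h2h = [nn.Linear(hidden_dim, hidden_dim) for _ in range(steps)]


class ModuleListNet(nn.Module):
    """Hypothetical example: the same layers held in nn.ModuleList ARE registered."""
    def __init__(self, input_dim, hidden_dim, steps=1):
        super().__init__()
        self.i2h = nn.ModuleList([nn.Linear(input_dim, hidden_dim) for _ in range(steps)])
        self.h2h = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(steps)])

    def forward(self, input, hidden):
        # zip stops at the shorter of the sequence length and the number of steps
        for inp, i2h, h2h in zip(input, self.i2h, self.h2h):
            hidden = torch.tanh(i2h(inp) + h2h(hidden))
        return hidden


def num_params(module):
    # Total number of scalar parameters the module exposes via .parameters()
    return sum(p.numel() for p in module.parameters())


plain = PlainListNet(4, 8, steps=3)
wrapped = ModuleListNet(4, 8, steps=3)
print(num_params(plain))    # 0
print(num_params(wrapped))  # 336 = 3 * ((4*8 + 8) + (8*8 + 8))

x = torch.randn(3, 2, 4)    # (seq_len, batch, input_dim)
h0 = torch.zeros(2, 8)      # (batch, hidden_dim)
print(wrapped(x, h0).shape) # torch.Size([2, 8])
```

nn.ModuleList also supports indexing (self.i2h[t]) and len(), so it covers the same use as the AttrProxy trick without the string-prefixed attribute names.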