How to create a module class with a dynamic number of layers


I’ve run into this several times: I design a network as a module class, but then I want to experiment with varying the number of layers. Unfortunately, the simplest solution, just creating a Python list of the modules in __init__, means that these submodules are not added to the children() of the main module. As a result I can’t use things like .cpu() and .cuda(), which rely on the internal _apply() method finding the children.
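A minimal sketch of the problem, assuming a recent PyTorch version: layers kept in a plain Python list are invisible to children() and parameters(), so tree-walking methods skip them. It uses .double() because it goes through _apply() just like .cuda() and .cpu(), but runs without a GPU. The class names PlainList and Registered are hypothetical.

```python
import torch
import torch.nn as nn

class PlainList(nn.Module):
    """Hypothetical example: layers stored in an ordinary Python list."""
    def __init__(self, n):
        super().__init__()
        self.layers = [nn.Linear(4, 4) for _ in range(n)]  # NOT registered as children

class Registered(nn.Module):
    """Hypothetical example: layers stored in nn.ModuleList."""
    def __init__(self, n):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(n)])  # registered

plain, reg = PlainList(3), Registered(3)
print(len(list(plain.children())), len(list(plain.parameters())))  # 0 0
print(len(list(reg.children())), len(list(reg.parameters())))      # 1 6

# .double() walks the module tree via _apply(), like .cuda() and .cpu()
plain.double()
reg.double()
print(plain.layers[0].weight.dtype)  # torch.float32: the plain list was skipped
print(reg.layers[0].weight.dtype)    # torch.float64
```

The same gap means an optimizer built from plain.parameters() would never see those layers.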

I have inspected the nn.Module class, and the only way I have come up with to get a list of modules into the tree of children is to create an iterable nn.Module subclass that takes the list of modules at initialization and uses setattr directly to register each module with its index as the name.
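For reference, nn.Module registers a child whenever an nn.Module instance is assigned as an attribute, and add_module(name, module) does the same thing explicitly. The name has to be a string (add_module raises a TypeError otherwise), so indices are converted with str(i). A minimal sketch of that approach; the class name IndexedLayers is hypothetical.

```python
import torch.nn as nn

class IndexedLayers(nn.Module):
    """Hypothetical container that registers each module under its index."""
    def __init__(self, modules):
        super().__init__()
        for i, m in enumerate(modules):
            # Equivalent to setattr(self, str(i), m): registers m as child "0", "1", ...
            self.add_module(str(i), m)

    def __iter__(self):
        # children() yields the registered submodules in registration order
        return iter(self.children())

    def __len__(self):
        return len(list(self.children()))

layers = IndexedLayers([nn.Linear(2, 2), nn.ReLU()])
print([name for name, _ in layers.named_children()])  # ['0', '1']
print(len(layers))  # 2
```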

I have put together an example of what I’m working on (see below), but I’m not sure whether I’m overthinking this, missing an obvious way to achieve this result, or overlooking a serious issue with what I’ve done. I’d appreciate any input on these uncertainties, and if you have a good way to do this, I would love to hear about it.

In particular, I would love to know if there is a way to add modules directly to the children without giving them a name (I suspect this is impossible, though).
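As far as I know, nn.Module stores its children in an ordered dictionary keyed by name, so every child does need a name. The built-in containers nn.ModuleList and nn.Sequential simply generate index names for you. A short check, assuming a recent PyTorch version:

```python
import torch
import torch.nn as nn

layers = nn.ModuleList()
for _ in range(3):
    layers.append(nn.Linear(4, 4))  # the caller supplies no name

print([name for name, _ in layers.named_children()])  # ['0', '1', '2']

seq = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
print([name for name, _ in seq.named_children()])  # ['0', '1', '2']
print(seq(torch.ones(1, 4)).shape)  # torch.Size([1, 2])
```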

My solution:

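```python
import torch
import torch.nn as nn

class dummy(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning parameters or submodules
        self.sample = nn.Parameter(torch.ones((5, 5)))

    def forward(self, x, full=False):
        print(self.sample.device)  # shows which device this layer's parameter is on
        if full:
            return [x]
        return x

class ModuleList(nn.Module):
    def __init__(self, modules):
        super().__init__()
        self.N = len(modules)  # plain int; no need for a buffer that moves between devices
        for i, module in enumerate(modules):
            # Assigning an nn.Module attribute registers it as a child; names must be strings
            setattr(self, str(i), module)

    def __iter__(self):
        for i in range(self.N):
            yield getattr(self, str(i))

    def __len__(self):
        return self.N

    # No forward(): this container is only iterated, never called directly

class TestModule(nn.Module):
    def __init__(self, N):
        super().__init__()
        self.N = N
        layers = []
        for _ in range(N):
            layers.append(dummy())
        self.layers = ModuleList(layers)

    def forward(self, x):
        y = x
        for layer in self.layers:
            y = layer(y)
        return y

M = TestModule(3)
y = M(torch.ones((2, 2)))  # prints "cpu" three times
M.cuda()                   # requires a CUDA device; moves all registered submodules
y = M(torch.ones((2, 2)))  # prints "cuda:0" three times
```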

And as proof that this solution did work, the output of the last four lines (testing) gave:



I’m not sure if I understand the question correctly, but nn.ModuleList should do exactly this: it registers the modules in a list as submodules.
Why did you implement your own ModuleList? Is nn.ModuleList not working or did you just miss this class?
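For comparison, a sketch of the question’s TestModule rewritten with the built-in nn.ModuleList. It assumes nn.Linear layers in place of the question’s dummy, and the width parameter is hypothetical, added only so the layers have matching shapes.

```python
import torch
import torch.nn as nn

class TestModule(nn.Module):
    def __init__(self, N, width=4):  # width: hypothetical, for illustration
        super().__init__()
        # nn.ModuleList registers each layer as a child named "0", "1", ...
        self.layers = nn.ModuleList([nn.Linear(width, width) for _ in range(N)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

M = TestModule(3)
print(list(M.state_dict().keys()))
# ['layers.0.weight', 'layers.0.bias', 'layers.1.weight', ...]

device = "cuda" if torch.cuda.is_available() else "cpu"
M.to(device)  # moves every registered layer, since they are real children
y = M(torch.ones(2, 4, device=device))
print(y.shape)  # torch.Size([2, 4])
```

Because the layers are registered, state_dict() saving and loading also work without extra code.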

Wow, I can’t believe I missed that and then gave my implementation the same name. Yes, this is exactly what I was looking for. Thanks!

Haha, that’s good to hear and apparently the naming is quite “intuitive” :wink: