I’ve run into this situation several times. I write a module class for a network design, and then I want to experiment with varying the number of layers. Unfortunately, the simplest solution, storing the submodules in a plain Python list during __init__, means they are not added to the main module’s children(). As a result, I can’t use methods like .cpu() and .cuda(), which rely on the internal _apply() method walking the children.
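To make the problem concrete, here is a minimal sketch comparing layers kept in a plain Python list with a layer assigned directly as an attribute. The class names PlainListNet and RegisteredNet are hypothetical. Only the attribute assignment goes through nn.Module’s registration, so only it shows up in children() and parameters().

```python
import torch.nn as nn


class PlainListNet(nn.Module):
    """Hypothetical: stores layers in an ordinary Python list (not registered)."""
    def __init__(self, n):
        super().__init__()
        self.layers = [nn.Linear(4, 4) for _ in range(n)]


class RegisteredNet(nn.Module):
    """Hypothetical: assigns one layer as an attribute (registered as a child)."""
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)


plain = PlainListNet(3)
registered = RegisteredNet()

# The plain list is invisible to children() and parameters()
print(len(list(plain.children())))          # 0
print(len(list(plain.parameters())))        # 0
print(len(list(registered.children())))     # 1
print(len(list(registered.parameters())))   # 2 (weight and bias)
```

Because parameters() comes up empty for the list version, an optimizer built from it would also miss those layers.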
I have inspected the nn.Module class, and the only way I have found to get a list of modules into the tree of children is to create an iterable nn.Module subclass that takes the list of modules at initialization and uses setattr directly to register each module with its index as the name.
I have put together an example of what I’m working on (see below), but I’m not sure whether I’m overthinking this, missing an obvious way to achieve this result, or overlooking a serious issue with what I’ve done. I would appreciate any input on these points, and if you have a good way to do this, I would love to hear about it.
In particular, I would love to know if there is a way to add to the children directly without supplying a name (I suspect this is impossible).
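For what it’s worth, PyTorch ships a built-in container for exactly this case, torch.nn.ModuleList. As far as I can tell, a child always needs a name (children are stored by name, and nn.Module.add_module(name, module) is the public way to register one), but ModuleList generates the names from the list index, so the caller never supplies them. A minimal sketch, where Stack is a hypothetical class name:

```python
import torch
import torch.nn as nn


class Stack(nn.Module):
    def __init__(self, n):
        super().__init__()
        # nn.ModuleList registers each element as a child; names are
        # generated from the index ("0", "1", ...), so none are given here
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(n)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


net = Stack(3)
net.layers.append(nn.Linear(4, 4))  # appending also registers the new module

print([name for name, _ in net.layers.named_children()])  # ['0', '1', '2', '3']
print(len(list(net.parameters())))                        # 8
print(net(torch.ones(2, 4)).shape)                         # torch.Size([2, 4])
```

A ModuleList also supports indexing (net.layers[0]), slicing, and len(), so it can stand in for a hand-written container like the one below.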
```python
import torch
import torch.nn as nn


class dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.sample = nn.Parameter(torch.ones((5, 5)))

    def forward(self, x, full=False):
        # Report which device this layer's parameter lives on
        print(self.sample.device)
        if full:
            return [x]
        return x


class ModuleList(nn.Module):
    def __init__(self, modules):
        super().__init__()
        self.register_buffer('N', torch.tensor(len(modules)))
        # Register each module as a child under its index. Names must be
        # strings, otherwise print(), named_modules() and state_dict() fail
        for i in range(len(modules)):
            self.__setattr__(str(i), modules[i])

    def __iter__(self):
        for i in range(self.N):
            yield self.__getattr__(str(i))

    def __len__(self):
        return self.N.item()

    def forward(self):
        print("DON'T CALL FORWARD ON THE MODULE LIST")


class TestModule(nn.Module):
    def __init__(self, N):
        super().__init__()
        self.N = N
        layers = []
        for i in range(N):
            layers.append(dummy())
        self.layers = ModuleList(layers)

    def forward(self, x):
        y = x
        for l in self.layers:
            y = l(y)
        return y


M = TestModule(3)
y = M(torch.ones((2, 2)))
M.cuda()
y = M(torch.ones((2, 2)))
```
As proof that this solution works, the last four lines (the test) printed:
```
cpu
cpu
cpu
cuda:0
cuda:0
cuda:0
```
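The same check can be run on a machine without a GPU, since .double() goes through the same _apply() traversal as .cuda(). Below is a minimal sketch of the setattr-with-index-name idea. Dummy and IndexedContainer are hypothetical stand-ins for the classes in the question, the length is kept as a plain int rather than a buffer, and the names are strings so that named_children() and state_dict() keep working.

```python
import torch
import torch.nn as nn


class Dummy(nn.Module):
    # Hypothetical stand-in for `dummy`: holds a single 5x5 parameter
    def __init__(self):
        super().__init__()
        self.sample = nn.Parameter(torch.ones(5, 5))


class IndexedContainer(nn.Module):
    # Hypothetical container: registers each module under its index as a string
    def __init__(self, modules):
        super().__init__()
        self.n = len(modules)  # plain int attribute, not registered as anything
        for i, module in enumerate(modules):
            setattr(self, str(i), module)

    def __iter__(self):
        for i in range(self.n):
            yield getattr(self, str(i))

    def __len__(self):
        return self.n


m = IndexedContainer([Dummy() for _ in range(3)])
print([layer.sample.dtype for layer in m])  # [torch.float32, torch.float32, torch.float32]

# .double() uses the same _apply() recursion as .cuda(), so every registered child is converted
m.double()
print([layer.sample.dtype for layer in m])  # [torch.float64, torch.float64, torch.float64]
```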