Using a variable number of modules as children when subclassing torch.nn.Module

I have a subclass of torch.nn.Module with multiple output heads that differ only by one parameter I pass into them. I want all of them to be accessible through the .children() call, but I want to create them in a loop. In the code below, only testClass3 shows any children at all (testClass1 shows none). Is there a way to achieve this with loops?

```python
import torch


class testClass2(torch.nn.Module):
    def __init__(self, i):
        super(testClass2, self).__init__()
        self.i = i


class testClass1(torch.nn.Module):
    def __init__(self):
        super(testClass1, self).__init__()
        # heads created in a loop and stored in a Python list
        self.test_classes = []
        for i in range(5):
            self.test_classes.append(testClass2(i))
        self.checkParams()

    def checkParams(self):
        for children in self.children():
            print(children)

class testClass3(torch.nn.Module):
    def __init__(self):
        super(testClass3, self).__init__()
        self.test_classes = []
        # heads assigned directly as attributes
        self.t1 = testClass2(1)
        self.t2 = testClass2(2)
        self.checkParams()

    def checkParams(self):
        for children in self.children():
            print(children)

t = testClass1()  # prints nothing
t = testClass3()  # prints testClass2() twice
```

I’m not sure how torch.nn.Module determines its “children”. My guess is that it checks whether each of its attributes is an instance of torch.nn.Module, or something like that.
Any insights appreciated.
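The guess above is close. Assigning a Module to an attribute goes through nn.Module.__setattr__, which records it as a child, while a plain Python list is stored as an ordinary attribute, so the modules inside it are never seen. A minimal sketch of that behaviour follows; the class names `Head` and `Demo` are hypothetical stand-ins for the question's classes. It also shows add_module, which registers a child under an explicit name and therefore works inside a loop.

```python
import torch


# Hypothetical stand-in for testClass2 from the question.
class Head(torch.nn.Module):
    def __init__(self, i):
        super().__init__()
        self.i = i


# Hypothetical container used only to illustrate registration.
class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Module assigned to an attribute: nn.Module.__setattr__ registers it.
        self.direct = Head(0)
        # Plain list: stored as an ordinary attribute, contents NOT registered.
        self.in_list = [Head(1)]
        # add_module registers each child under an explicit name.
        for i in range(2, 4):
            self.add_module(f"head{i}", Head(i))


demo = Demo()
names = [name for name, _ in demo.named_children()]
print(names)  # ['direct', 'head2', 'head3']
```

Calling `setattr(self, f"head{i}", Head(i))` in the loop would register the heads the same way, since it also goes through nn.Module.__setattr__.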

PyTorch registers the submodules as children if you store them in an nn.ModuleList instead of a Python list.
Use self.test_classes = torch.nn.ModuleList() (you can keep calling .append() in the loop) and it should work.
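Applied to testClass1 from the question, the ModuleList fix might look like the sketch below. The checkParams helper is dropped for brevity; otherwise the loop is unchanged except for the container type.

```python
import torch


class testClass2(torch.nn.Module):
    def __init__(self, i):
        super().__init__()
        self.i = i


class testClass1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # ModuleList registers every module appended to it, so the heads are
        # visible to .children(), .modules(), .parameters(), .to(), etc.
        self.test_classes = torch.nn.ModuleList()
        for i in range(5):
            self.test_classes.append(testClass2(i))


t = testClass1()
for child in t.children():
    print(child)  # prints the ModuleList holding the five heads
```

Note that .children() yields only immediate children, so t.children() yields the single ModuleList rather than the five heads. To reach the heads, iterate or index t.test_classes directly (it supports len(), indexing and iteration), or use t.modules() to walk the whole tree.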
