I have a subclass of torch.nn.Module with multiple output heads that differ only by one parameter I pass into them. I want all of them to be accessible through the .children() call, but I want to create them in a loop. In the code below, only testClass3 prints any children at all; testClass1 prints none. Is there a way to achieve this with loops?
```python
import torch

class testClass2(torch.nn.Module):
    def __init__(self, i):
        super(testClass2, self).__init__()
        self.i = i

class testClass1(torch.nn.Module):
    def __init__(self):
        super(testClass1, self).__init__()
        self.test_classes = []
        for i in range(5):
            self.test_classes.append(testClass2(i))
        self.checkParams()

    def checkParams(self):
        for children in self.children():
            print(children)

class testClass3(torch.nn.Module):
    def __init__(self):
        super(testClass3, self).__init__()
        self.test_classes = []
        self.t1 = testClass2(1)
        self.t2 = testClass2(2)
        self.checkParams()

    def checkParams(self):
        for children in self.children():
            print(children)

t = testClass1()
t = testClass3()
```
I’m not sure how torch.nn.Module determines its “children”. My guess is that it checks each of its attributes to see whether it is an instance of torch.nn.Module, or something like that.
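For what it's worth, the mechanism is close to that guess, with one important difference: nn.Module overrides `__setattr__`, so a value is registered as a child only at the moment it is assigned directly as an attribute and is itself an nn.Module. A plain Python list is stored as an ordinary attribute, and its contents are never inspected. The minimal check below assumes the hypothetical class names `Head`, `ListHolder` and `AttrHolder`, which mirror testClass2, testClass1 and testClass3.

```python
import torch


class Head(torch.nn.Module):
    """Hypothetical stand-in for testClass2: a module that only stores an int."""

    def __init__(self, i):
        super().__init__()
        self.i = i


class ListHolder(torch.nn.Module):
    """Mirrors testClass1: heads kept in a plain Python list."""

    def __init__(self):
        super().__init__()
        # __setattr__ sees a list, not a Module, so nothing gets registered.
        self.heads = [Head(i) for i in range(3)]


class AttrHolder(torch.nn.Module):
    """Mirrors testClass3: heads assigned directly as attributes."""

    def __init__(self):
        super().__init__()
        # Each assignment of a Module value is registered as a child.
        self.t1 = Head(1)
        self.t2 = Head(2)


list_holder = ListHolder()
attr_holder = AttrHolder()
print(len(list(list_holder.children())))                  # 0
print([name for name, _ in attr_holder.named_children()])  # ['t1', 't2']
```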
Any insights appreciated.
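A sketch of two loop-based ways to get the heads registered, assuming the same hypothetical `Head` class in place of testClass2. `torch.nn.ModuleList` is a list-like container that is itself a Module, so it shows up as a single child and registers every element inside it. Alternatively, `add_module` (or `setattr`) inside the loop registers each head as its own direct child under a generated name such as `head0`.

```python
import torch


class Head(torch.nn.Module):
    """Hypothetical stand-in for testClass2."""

    def __init__(self, i):
        super().__init__()
        self.i = i


class LoopModuleList(torch.nn.Module):
    def __init__(self, n_heads=5):
        super().__init__()
        # ModuleList is registered as one child; its elements are registered
        # inside it and can still be indexed and iterated like a list.
        self.test_classes = torch.nn.ModuleList(Head(i) for i in range(n_heads))


class LoopAddModule(torch.nn.Module):
    def __init__(self, n_heads=5):
        super().__init__()
        # add_module registers each head under its own name, so every head
        # appears directly in children().
        for i in range(n_heads):
            self.add_module(f"head{i}", Head(i))


m1 = LoopModuleList()
m2 = LoopAddModule()
print([name for name, _ in m1.named_children()])  # ['test_classes']
print([head.i for head in m1.test_classes])       # [0, 1, 2, 3, 4]
print([name for name, _ in m2.named_children()])  # ['head0', 'head1', 'head2', 'head3', 'head4']
```

ModuleList keeps the heads grouped under a single attribute, which is convenient when a forward pass simply iterates over them. `add_module` is closer to the testClass3 layout if each head should appear as a separate child.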