I got curious about how my model — which is composed of several other models — knows the parameters of all of them.
For instance, assume I had the following code:
class Net1(nn.Module):
    """Small 2 -> 2 MLP: two Linear(2, 2) layers, each followed by ReLU."""

    def __init__(self):
        super().__init__()
        # Assigning the Sequential to an attribute registers it as a
        # submodule, so its parameters are tracked automatically.
        self.fc = nn.Sequential(
            nn.Linear(2, 2),
            nn.ReLU(),
            nn.Linear(2, 2),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the stacked layers to *x* (last dim must be 2)."""
        return self.fc(x)
class Net2(nn.Module):
    """Small 2 -> 8 MLP: Linear(2, 8) then Linear(8, 8), ReLU after each."""

    def __init__(self):
        super().__init__()
        # Registered as a submodule via attribute assignment; its
        # parameters appear in self.parameters() automatically.
        self.fc = nn.Sequential(
            nn.Linear(2, 8),
            nn.ReLU(),
            nn.Linear(8, 8),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the stacked layers to *x* (last dim must be 2)."""
        return self.fc(x)
class Net3(nn.Module):
    """Composite model: runs input through net1, then net2, then Linear(8, 10).

    Args:
        net1: first submodule applied to the input.
        net2: second submodule; its output must have last dim 8.
    """

    def __init__(self, net1, net2):
        super().__init__()
        # nn.Module.__setattr__ intercepts these assignments and registers
        # net1/net2/fc as submodules, which is why net3.parameters()
        # includes the parameters of all three networks.
        self.net1 = net1
        self.net2 = net2
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        """Chain net1 -> net2 -> final linear projection to 10 outputs."""
        x = self.net1(x)
        x = self.net2(x)
        return self.fc(x)
If I instantiated all three and then printed net3’s parameters, the output would include the parameters of all three networks. This is great — it’s exactly what you want — but I can’t figure out how PyTorch manages to do that internally.