I have two networks that share some common layers (`conv1` and `conv2`). Is the following code correct?
class CommonNet(nn.Module):
    """Base network holding the convolution layers common to its subclasses.

    By default every instance builds its own ``conv1`` and ``conv2``. Two
    networks derived from this class therefore share the same *architecture*
    but NOT the same *weights*; each trains its own copy of the layers.

    To make two networks use literally the same layers, build one instance
    and hand its modules to the other, e.g.
    ``CommonNet(conv1=a.conv1, conv2=a.conv2)``. Assigning the same
    ``nn.Module`` object to two parents registers the same parameters in both.

    Args:
        conv1: Optional pre-built module to use as the first layer. When
            ``None`` (the default), a new ``nn.Conv2d(3, 32, 3)`` is created.
        conv2: Optional pre-built module to use as the second layer. When
            ``None`` (the default), a new ``nn.Conv2d(32, 64, 3)`` is created.
    """

    def __init__(self, conv1=None, conv2=None) -> None:
        super().__init__()
        # Layers are created in the original order (conv1, then conv2), so
        # default construction consumes the RNG exactly as before.
        # 3 input channels -> 32 feature maps, 3x3 kernel.
        self.conv1 = nn.Conv2d(3, 32, 3) if conv1 is None else conv1
        # 32 -> 64 feature maps, 3x3 kernel; subclass heads consume 64 channels.
        self.conv2 = nn.Conv2d(32, 64, 3) if conv2 is None else conv2
class Net1(CommonNet):
    """``CommonNet``'s conv1/conv2 layers followed by a 4-channel output conv."""

    def __init__(self):
        super().__init__()
        # Net1-specific head: 64 channels (conv2's output) -> 4 channels, 3x3.
        self.final_conv = nn.Conv2d(64, 4, 3)

    def forward(self, x):
        """Run ``x`` through conv1, conv2 and the Net1 head.

        Args:
            x: Input tensor. ``conv1`` is built with 3 input channels, so this
                presumably is a ``(N, 3, H, W)`` batch — confirm against callers.

        Returns:
            Tensor with 4 output channels. With ``Conv2d``'s default stride 1
            and zero padding, each of the three 3x3 layers trims 2 pixels from
            H and W.
        """
        # NOTE(review): no activation between layers. Stacked convolutions
        # without a nonlinearity compose to a single linear map; confirm that
        # this is intended.
        out = self.conv1(x)
        out = self.conv2(out)
        return self.final_conv(out)
class Net2(CommonNet):
    """``CommonNet``'s conv1/conv2 layers followed by a 2-channel output conv."""

    def __init__(self):
        super().__init__()
        # Net2-specific head: 64 channels (conv2's output) -> 2 channels, 3x3.
        self.final_conv = nn.Conv2d(64, 2, 3)

    def forward(self, x):
        """Run ``x`` through conv1, conv2 and the Net2 head.

        Args:
            x: Input tensor. ``conv1`` is built with 3 input channels, so this
                presumably is a ``(N, 3, H, W)`` batch — confirm against callers.

        Returns:
            Tensor with 2 output channels. With ``Conv2d``'s default stride 1
            and zero padding, each of the three 3x3 layers trims 2 pixels from
            H and W.
        """
        # NOTE(review): no activation between layers. Stacked convolutions
        # without a nonlinearity compose to a single linear map; confirm that
        # this is intended.
        out = self.conv1(x)
        out = self.conv2(out)
        return self.final_conv(out)