Do I have to inherit from `nn.Module`?

I have 2 networks with some common layers. So is the following code correct?

class CommonNet(nn.Module):
  """Base module that owns the two convolution layers shared by subclasses.

  Defines no ``forward`` of its own; subclasses are expected to add their
  own layers and implement ``forward`` using ``conv1`` and ``conv2``.
  """

  def __init__(self):
    # nn.Module must be initialised before any sub-module is assigned.
    super().__init__()
    # Shared layers: 3 -> 32 channels, then 32 -> 64 channels, both 3x3.
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)

class Net1(CommonNet):
  """Network built from CommonNet's shared convolutions plus a 4-channel head."""

  def __init__(self):
    # CommonNet.__init__ registers the shared conv1 / conv2 layers.
    super().__init__()
    # Network-specific head: 64 input channels -> 4 output channels, 3x3 kernel.
    self.final_conv = nn.Conv2d(64, 4, 3)

  def forward(self, x):
    """Apply conv1, conv2 and final_conv to ``x`` in sequence and return the result."""
    out = self.conv1(x)
    # Fixed typo: this was `self.covn2`, an attribute that does not exist and
    # therefore raised AttributeError on the first forward pass.
    out = self.conv2(out)
    out = self.final_conv(out)
    return out

class Net2(CommonNet):
  """Network built from CommonNet's shared convolutions plus a 2-channel head."""

  def __init__(self):
    # CommonNet.__init__ registers the shared conv1 / conv2 layers.
    super().__init__()
    # Network-specific head: 64 input channels -> 2 output channels, 3x3 kernel.
    self.final_conv = nn.Conv2d(64, 2, 3)

  def forward(self, x):
    """Apply conv1, conv2 and final_conv to ``x`` in sequence and return the result."""
    out = self.conv1(x)
    # Fixed typo: this was `self.covn2`, an attribute that does not exist and
    # therefore raised AttributeError on the first forward pass.
    out = self.conv2(out)
    out = self.final_conv(out)
    return out

You need to call `super().__init__()` in `CommonNet`, but otherwise the code is correct (apart from a few typos).

1 Like