Defining model components while building the computation graph?

Hi, I noticed that when we define a PyTorch model, we usually need to declare its components in __init__() before using them in forward(). For example:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning any submodules
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
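Part of what makes this bookkeeping tedious is that sizes such as `16 * 5 * 5` in fc1 have to be worked out by hand. A small sketch that traces the shapes through the same conv/pool layers, assuming a 3×32×32 input (the size the hard-coded `16 * 5 * 5` implies):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same convolution/pooling layers as in Net; the 3x32x32 input size is an assumption
conv1 = nn.Conv2d(3, 6, 5)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(6, 16, 5)

x = torch.randn(1, 3, 32, 32)
x = pool(F.relu(conv1(x)))
print(x.shape)  # torch.Size([1, 6, 14, 14])   32 -> 28 (conv 5x5) -> 14 (pool)
x = pool(F.relu(conv2(x)))
print(x.shape)  # torch.Size([1, 16, 5, 5])    14 -> 10 (conv 5x5) -> 5 (pool)
x = torch.flatten(x, 1)
print(x.shape)  # torch.Size([1, 400]), i.e. 16 * 5 * 5 = in_features of fc1
```

With a different input size, fc1's `in_features` would have to be recomputed by hand.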

Here we need to define self.conv1, self.conv2, etc. in advance. In Keras, by contrast, we usually create these components directly as we build the computation graph, e.g.:

```python
from tensorflow import keras
from tensorflow.keras import layers

encoder_input = keras.Input(shape=(28, 28, 1), name="img")
x = layers.Conv2D(16, 3, activation="relu")(encoder_input)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.MaxPooling2D(3)(x)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.Conv2D(16, 3, activation="relu")(x)
```

Here we do not need to declare the Conv2D layers in advance; instead, each one is created at the point where it is used in the graph.
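For comparison, one possible PyTorch analogue of the Keras snippet is sketched below using `nn.Sequential` together with lazy modules (`nn.LazyConv2d` infers `in_channels` from the first forward pass, roughly like Keras infers input shapes). This assumes PyTorch 1.8 or later and is only a sketch; note that PyTorch expects channels-first input (N, C, H, W) rather than Keras's channels-last layout.

```python
import torch
import torch.nn as nn

# Layers are declared in order, once; in_channels is inferred lazily (assumes PyTorch >= 1.8)
model = nn.Sequential(
    nn.LazyConv2d(16, 3), nn.ReLU(),
    nn.LazyConv2d(32, 3), nn.ReLU(),
    nn.MaxPool2d(3),  # stride defaults to kernel_size, like Keras MaxPooling2D(3)
    nn.LazyConv2d(32, 3), nn.ReLU(),
    nn.LazyConv2d(16, 3), nn.ReLU(),
)

# A dummy forward pass materializes the lazy parameters (channels-first: 1x28x28)
x = torch.randn(1, 1, 28, 28)
out = model(x)
print(out.shape)  # torch.Size([1, 16, 4, 4])
```

The spatial sizes follow 28 → 26 → 24 → 8 (pool) → 6 → 4. Since lazy parameters only exist after the first forward pass, an optimizer should be created after running a dummy batch through the model.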

I am wondering whether there is a similar way to do this in PyTorch. I am asking because:

  1. it can be error-prone to keep track of which components I have defined in __init__() and which ones are actually used in forward();
  2. when the number of layers is large, it is not very convenient to give every layer a unique name/index.
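Regarding the second point, layers that repeat do not need individual names: they can be created in a loop and stored in an `nn.ModuleList`, which (unlike a plain Python list) registers their parameters with the parent module. The `MLP` class and its arguments below are hypothetical, chosen only to illustrate the pattern.

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    """Hypothetical example: `depth` hidden layers built in a loop, no unique names needed."""

    def __init__(self, in_features, hidden, out_features, depth):
        super().__init__()
        dims = [in_features] + [hidden] * depth
        # nn.ModuleList registers every layer, so model.parameters() sees them all
        self.hidden = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        self.head = nn.Linear(dims[-1], out_features)

    def forward(self, x):
        # Iterate over the same list that __init__ built, so the two stay in sync
        for layer in self.hidden:
            x = torch.relu(layer(x))
        return self.head(x)

model = MLP(in_features=8, hidden=32, out_features=2, depth=10)
print(len(model.hidden))               # 10
print(model(torch.randn(5, 8)).shape)  # torch.Size([5, 2])
```

Because forward() simply iterates over what __init__() created, adding or removing layers only means changing `depth`.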

If anyone has an idea, please let me know. Thanks!
