Split network into backbone and head

Here is the relevant snippet of my neural network code:

import torch
import torch.nn as nn

class NeuralNetwork(nn.Module):
    def __init__(self, input_size: int):
        super().__init__()
        self.backbone = SpotRotBackbone()
        # input size of the head computed from the backbone's structure
        self.head = SpotRotHead(input_size * input_size * 16)

    def forward(self, x):
        x = self.backbone(x)
        return self.head(x)

Now, the reason I am using this line:

self.head = SpotRotHead(input_size * input_size * 16)

is that I know what the backbone does.

Another way I found, much more general but still ugly, is this one:

    def __init__(self, input_size: int):
        super().__init__()
        self.backbone = SpotRotBackbone()
        # can't find a better way: infer the head's input size from a dummy forward pass
        output_dimensions = self.backbone(torch.zeros(1, 3, input_size, input_size)).shape
        self.head = SpotRotHead(output_dimensions[-1])

What do you think? How would you do it?

If you cannot create a fixed activation shape after the backbone, e.g. via adaptive pooling layers, both approaches are valid: pre-computing the input shape or using a fake forward pass. For the latter, make sure the dummy forward pass with torch.zeros does not change anything inside self.backbone (e.g. does not update the running stats of batchnorm layers).
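
A minimal sketch of that dummy-forward variant, reusing SpotRotBackbone and SpotRotHead from the question and assuming the backbone returns an already flattened [N, features] tensor: switching the backbone to eval mode and disabling gradients keeps batchnorm running stats (and dropout) untouched during the shape probe.

import torch
import torch.nn as nn

class NeuralNetwork(nn.Module):
    def __init__(self, input_size: int):
        super().__init__()
        self.backbone = SpotRotBackbone()
        # probe the backbone once to find the head's input size;
        # eval() + no_grad() avoid updating batchnorm stats or building a graph
        self.backbone.eval()
        with torch.no_grad():
            out = self.backbone(torch.zeros(1, 3, input_size, input_size))
        self.backbone.train()
        self.head = SpotRotHead(out.shape[-1])

    def forward(self, x):
        return self.head(self.backbone(x))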

What do you mean, @ptrblck?

The model’s head looks like this:

class SpotRotHead(nn.Module):
    """Classification of the image rotation angle.

    input_size: the number of input features to the first linear layer.
    """

    def __init__(self, input_size: int):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 16)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        return x

Before feeding an activation to the first linear layer, adaptive pooling layers are often used, as they let you create a defined output shape regardless of the input resolution. If that doesn't fit your use case, your approaches look valid.
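
A sketch of that idea, assuming the backbone returns a 4D feature map [N, C, H, W] and keeping the 4 rotation classes from the head above (in_channels is a placeholder for your backbone's actual channel count):

import torch
import torch.nn as nn

class SpotRotHead(nn.Module):
    """Classifies the rotation angle from a feature map of any spatial size."""

    def __init__(self, in_channels: int, num_classes: int = 4):
        super().__init__()
        # adaptive pooling collapses any HxW feature map to 1x1, so the
        # number of input features to fc1 depends only on the channel count
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(in_channels, 16)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(16, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(x)           # [N, C, H, W] -> [N, C, 1, 1]
        x = torch.flatten(x, 1)    # [N, C]
        x = self.dropout1(self.relu1(self.fc1(x)))
        return self.fc2(x)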

Sounds quite interesting. Do you have any link or code sample so I can see this?

Almost all CNNs use it, as seen here for ResNet.
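
For example, with torchvision installed you can check that ResNet keeps an nn.AdaptiveAvgPool2d((1, 1)) right before its final fc layer, which is why it accepts different input resolutions:

import torch
import torchvision

# the adaptive pooling layer makes the feature size independent of the input resolution
model = torchvision.models.resnet18()
print(model.avgpool)                              # AdaptiveAvgPool2d(output_size=(1, 1))
print(model(torch.randn(1, 3, 224, 224)).shape)   # torch.Size([1, 1000])
print(model(torch.randn(1, 3, 160, 160)).shape)   # torch.Size([1, 1000])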

Thank you! A complementary link is

So if I get the point, this always produces the same output size, which is quite interesting. I do wonder, though, whether using an AdaptiveAvgPool would worsen the performance significantly.

Aggressive pooling could affect the accuracy, so you would need to test it. However, note that it also comes down to allowing users to pass a variable input shape vs. raising a shape mismatch error. A reduction in accuracy might thus be more user-friendly than a runtime error.
