Slicing MobileNetV3: matrix shape mismatch

I have been trying to obtain the output of MobileNetV3 up to, but not including, the last linear layer. The last module in mobilenet_v3_small is an nn.Sequential:

print(mobilenet_v3_small.classifier)
Sequential(
  (0): Linear(in_features=576, out_features=1024, bias=True)
  (1): Hardswish()
  (2): Dropout(p=0.2, inplace=True)
  (3): Linear(in_features=1024, out_features=1000, bias=True)
)

I would like to keep only the first Linear layer and everything that precedes it in the network. Here’s my attempt:

import torch.nn as nn
from torchvision import models

class mobilenetv3(nn.Module):
    def __init__(self):
        super().__init__()
        mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
        self.mobilenet_preclassifier = nn.Sequential(mobilenet_v3_small.features, 
                                                     mobilenet_v3_small.avgpool)
        self.mobilenet_linear = nn.Sequential(mobilenet_v3_small.classifier[:1])
        
    def forward(self, x):
        x = self.mobilenet_preclassifier(x)
        x = self.mobilenet_linear(x)
        return x

But I keep getting the following error:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (576x1 and 576x1024).

The error stays the same even if I assemble the complete network module by module without slicing the classifier, for example:

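import torch.nn as nn
from torchvision import models

mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)

class mobilenetv3(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = nn.Sequential(mobilenet_v3_small.features, 
                               mobilenet_v3_small.avgpool, 
                               mobilenet_v3_small.classifier)
        
    def forward(self, x):
        x = self.m(x)
        return x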

This suggests that I am assembling the modules incorrectly. I would be grateful if someone could point out the appropriate way to do it and explain what is wrong with my approach. Thanks in advance.

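Wrapping submodules in an nn.Sequential container often fails, because any functional API calls made in the original forward method are lost.
In your case, the torch.flatten(x, 1) call that MobileNetV3 applies between avgpool and classifier is missing, so the first Linear layer receives a [N, 576, 1, 1] tensor instead of [N, 576]. nn.Linear operates on the last dimension (size 1 here), which yields the shape mismatch.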
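A minimal sketch reproducing the mismatch with plain PyTorch modules. It assumes a hypothetical 7x7 feature map with 576 channels; the `features` tensor is random data, not real MobileNetV3 output, and the modules are fresh stand-ins for avgpool and the first classifier layer:

```python
import torch
import torch.nn as nn

# Stand-ins for MobileNetV3-Small's avgpool and first classifier layer
# (576 -> 1024, as in the printed classifier above).
avgpool = nn.AdaptiveAvgPool2d(1)
linear = nn.Linear(576, 1024)

features = torch.randn(1, 576, 7, 7)   # hypothetical feature map for one image
pooled = avgpool(features)
print(pooled.shape)                    # torch.Size([1, 576, 1, 1])

# nn.Linear acts on the last dimension, which has size 1 here,
# so feeding the 4D tensor directly raises a RuntimeError.
try:
    linear(pooled)
    failed = False
except RuntimeError as err:
    failed = True
    print(err)

# torch.flatten(x, 1) keeps the batch dimension and merges the rest: [1, 576].
flat = torch.flatten(pooled, 1)
out = linear(flat)
print(out.shape)                       # torch.Size([1, 1024])
```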
You could either add an nn.Flatten module to your nn.Sequential model (between avgpool and the classifier) or derive a custom class and override its forward method.

Thanks. That solved the issue.