EfficientNet Error - RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x1000 and 1408x512)

I am trying to train an EfficientNet model in PyTorch. Below is the code I have written for the model. Running it raises the error "RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x1000 and 1408x512)".

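import torch
import torch.nn as nn
from torchvision import models
# Assumption: EfficientNet here comes from the efficientnet_pytorch package
from efficientnet_pytorch import EfficientNet

backbone_model = EfficientNet.from_pretrained('efficientnet-b2', include_top=False)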
x = torch.randn(1,3,150,150)
features = backbone_model(x)
num_channels = features.shape[1]
print(num_channels, "num_channels")

num_classes = len(CLASSES)

class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.backbone_model = models.efficientnet_b2(pretrained=True)
        self.backbone_model._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.dropout1 = nn.Dropout(p=0.2)
        self.fc1 = nn.Linear(1408, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc_out = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(0.2)
    def forward(self, x):
        x = self.backbone_model(x)
        print(x.shape, "shape1")
        x = torch.flatten(x, 1)
        print(x.shape, "shape2")
        print("fc1 input size", self.fc1.in_features)
        x = self.dropout(x)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.fc_out(x)
        return x

print(num_channels, "num_channels") → 1408 num_channels

print(x.shape, "shape1") → torch.Size([4, 1000]) shape1

print(x.shape, "shape2") → torch.Size([4, 1000]) shape2

print("fc1 input size", self.fc1.in_features) → fc1 input size 1408

How can I solve this error?

From the looks of it, I would try changing self.fc1 = nn.Linear(1408, 512) to self.fc1 = nn.Linear(1000, 512), since the tensor reaching fc1 has 1000 features, not 1408. This is just an intuition.
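The mismatch can be reproduced without the backbone at all. Below is a minimal sketch that uses a random tensor as a stand-in for the (4, 1000) backbone output printed in the question; the variable names are made up for illustration.

```python
import torch
import torch.nn as nn

# Stand-in for the backbone output from the question: 4 samples, 1000 values each.
logits = torch.randn(4, 1000)

fc_mismatched = nn.Linear(1408, 512)  # expects 1408 input features
fc_matched = nn.Linear(1000, 512)     # expects 1000 input features

try:
    fc_mismatched(logits)
    mismatch_raised = False
except RuntimeError as err:
    # Same kind of shape error as in the question
    mismatch_raised = True
    print("RuntimeError:", err)

out = fc_matched(logits)
print(out.shape)  # torch.Size([4, 512])
```

Changing in_features to 1000 makes the shapes agree, but the new layers then sit on top of the 1000-class ImageNet logits instead of the pooled backbone features.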


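Upon further investigation, I found that the classifier block of the torchvision EfficientNet is still in place and still outputs 1000 features, so the fully connected layers you added are stacked on top of it instead of replacing it. The 1408 you printed comes from the separate EfficientNet.from_pretrained(..., include_top=False) model, not from the models.efficientnet_b2 instance used inside CustomModel. Also, the torchvision model names its pooling layer avgpool, so assigning _avg_pooling only adds an unused attribute and does not change the forward pass. That is why fc1 receives a 4x1000 tensor.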