Split EfficientNet model into its individual stages

Hey there :slight_smile:

I have to split the EfficientNet-b4 model with respect to its stages. Table 1 of the original paper lists the individual stages of EfficientNet. I am curious whether my approach is correct or will lead to wrong results.

Here is a code snippet:

import torch
import torch.nn as nn
from torchvision.models import efficientnet_b4

class Net(nn.Module):
    def __init__(self):
      
        super(Net, self).__init__()
        # Init efficientnet-b4 model and custom layers
        self.efficientnetb4 = efficientnet_b4()
        ...
        # Split the efficientnet model into submodels w.r.t. the stages 1...9 [1]
        stages = list(self.efficientnetb4.features.children())
        self.stage_one = nn.Sequential(*stages[0])
        self.stage_two = nn.Sequential(*stages[1])
        self.stage_three = nn.Sequential(*stages[2])
        self.stage_four = nn.Sequential(*stages[3])
        self.stage_five = nn.Sequential(*stages[4])
        self.stage_six = nn.Sequential(*stages[5])
        self.stage_seven = nn.Sequential(*stages[6])
        self.stage_eight = nn.Sequential(
            *stages[7])
        self.stage_nine = nn.Sequential(
            *stages[8],
            self.efficientnetb4.avgpool)
        self.efficientnetb4.classifier = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Flatten(),
            nn.Linear(in_features=1792, out_features=1300, bias=True),
            nn.BatchNorm1d(1300),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=1300, out_features=650, bias=True),
            nn.BatchNorm1d(650),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=650, out_features=325),
            nn.BatchNorm1d(325),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=325, out_features=160, bias=True),
            nn.BatchNorm1d(160),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=160, out_features=1))

Within the forward method of the torchvision EfficientNet implementation there is a functional API call (torch.flatten), which my split model does not reproduce. I am not sure whether this will lead to wrong results. Perhaps someone can give me an answer.

Best regards

Missing functional API calls sound wrong and would change the output.
I would thus recommend setting the model into .eval() mode and comparing its outputs with the original model to see if anything else causes numerical mismatches.
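The suggested check can be sketched as follows. This is a minimal toy example (small conv layers stand in for the EfficientNet stages) showing the principle: re-use the original model's modules in the split model so the weights are shared, switch to eval mode, and compare the outputs, including the functional `torch.flatten` call that torchvision applies between the pooling and the classifier.

```python
import torch
import torch.nn as nn

# Toy "original" model; the layers stand in for EfficientNet's features + avgpool.
torch.manual_seed(0)
original = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.Conv2d(8, 8, 3, padding=1),
    nn.AdaptiveAvgPool2d(1),
)

# Build the "split" model from the very same modules, so both share weights.
stage_one = original[0]
stage_two = original[1]
avgpool = original[2]

original.eval()
x = torch.randn(2, 3, 16, 16)

with torch.no_grad():
    # Reference path: full model + functional flatten, as torchvision does it.
    out_ref = torch.flatten(original(x), 1)
    # Split path: stage by stage; the flatten must not be forgotten here.
    out_split = avgpool(stage_two(stage_one(x)))
    out_split = torch.flatten(out_split, 1)

print(torch.allclose(out_ref, out_split))  # True if the split is faithful
```

With pretrained EfficientNet weights loaded into both models, the same comparison (e.g. `torch.allclose` on the outputs) reveals whether the split introduces any mismatch.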

Hey ptrblck,
thanks for your response.

I had to initialize both networks with the same pre-trained weights to ensure that they have identical parameters. After adding the missing functional API call (torch.flatten) between stage nine and the classifier, both networks produce the same output in evaluation mode.
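For reference, a hypothetical forward pass with the fix in place could look like this (toy stages stand in for the real EfficientNet blocks; the point is the `torch.flatten` call between the final pooled stage and the classifier):

```python
import torch
import torch.nn as nn

class SplitNet(nn.Module):
    """Toy split model illustrating where the functional flatten belongs."""
    def __init__(self):
        super().__init__()
        self.stage_one = nn.Conv2d(3, 8, 3, padding=1)
        # Final stage ends with the average pool, as in the split above.
        self.stage_nine = nn.Sequential(
            nn.Conv2d(8, 1792, 1),
            nn.AdaptiveAvgPool2d(1),  # output shape (N, 1792, 1, 1)
        )
        self.classifier = nn.Linear(1792, 1)

    def forward(self, x):
        x = self.stage_one(x)
        x = self.stage_nine(x)
        x = torch.flatten(x, 1)  # the previously missing functional call
        return self.classifier(x)

net = SplitNet().eval()
out = net(torch.randn(2, 3, 16, 16))
print(out.shape)  # torch.Size([2, 1])
```

In the actual model the intermediate stages two through eight would of course be called in between; alternatively, since `nn.Flatten()` is already the second layer of the custom classifier, the functional call and the module-based flatten should not both be applied.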

Best regards