Hey there
I have to split the EfficientNet-b4 model into its stages. Table 1 of the original EfficientNet paper lists the individual stages of the network. I am curious whether my approach is correct or will lead to wrong results.
Here is a code snippet:
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b4


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Init efficientnet-b4 model and custom layers
        self.efficientnetb4 = efficientnet_b4()
        ...
        # Split efficientnet model into submodels w.r.t. the stages 1...9 [1]
        self.stage_one = nn.Sequential(
            *list(self.efficientnetb4.features.children())[0])
        self.stage_two = nn.Sequential(
            *list(self.efficientnetb4.features.children())[1])
        self.stage_three = nn.Sequential(
            *list(self.efficientnetb4.features.children())[2])
        self.stage_four = nn.Sequential(
            *list(self.efficientnetb4.features.children())[3])
        self.stage_five = nn.Sequential(
            *list(self.efficientnetb4.features.children())[4])
        self.stage_six = nn.Sequential(
            *list(self.efficientnetb4.features.children())[5])
        self.stage_seven = nn.Sequential(
            *list(self.efficientnetb4.features.children())[6])
        self.stage_eight = nn.Sequential(
            *list(self.efficientnetb4.features.children())[7])
        self.stage_nine = nn.Sequential(
            *list(self.efficientnetb4.features.children())[8],
            self.efficientnetb4.avgpool)
        self.efficientnetb4.classifier = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Flatten(),
            nn.Linear(in_features=1792, out_features=1300, bias=True),
            nn.BatchNorm1d(1300),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=1300, out_features=650, bias=True),
            nn.BatchNorm1d(650),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=650, out_features=325),
            nn.BatchNorm1d(325),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=325, out_features=160, bias=True),
            nn.BatchNorm1d(160),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=160, out_features=1))
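The forward method of the torchvision EfficientNet-b4 implementation contains a functional call (torch.flatten) between the pooling layer and the classifier. Since a functional call is not a submodule, it is not part of the stages above. I am not sure whether this will lead to wrong results. Perhaps someone can give me an answer.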
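For what it is worth, `torch.flatten(x, 1)` only reshapes the pooled (N, 1792, 1, 1) tensor into (N, 1792), and `nn.Flatten()` with its defaults (`start_dim=1`, `end_dim=-1`) does the same as a module. A small check, using random data as a stand-in for the avgpool output (the batch size of 4 is arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Random stand-in for the avgpool output: (batch, channels, 1, 1)
pooled = torch.randn(4, 1792, 1, 1)

# What torchvision's EfficientNet.forward does after avgpool
functional = torch.flatten(pooled, 1)
# What nn.Flatten() in the custom classifier does (start_dim=1 by default)
modular = nn.Flatten()(pooled)
print(functional.shape, torch.equal(functional, modular))

# Flattening an already flat (N, 1792) tensor leaves its shape unchanged
print(nn.Flatten()(functional).shape)
```

So the result is the same whether the classifier receives the 4-D output of stage nine or the already flattened tensor inside `self.efficientnetb4(x)`. The leading `nn.Dropout` zeroes elements independently, so it behaves the same on the 4-D tensor as on the flattened one.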
Best regards