Hi, I’m training a model with transfer learning on the pizza_steak_sushi dataset, and when I try to update out_features directly, performance is much worse. Example below.
I am using EfficientNet_B0
weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT
model = torchvision.models.efficientnet_b0(weights=weights)
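The snippets below also assume these imports and a device (shown just for completeness, since I use nn and device later on):

import torch
import torchvision
from torch import nn

# device used later when moving the new classifier layer
device = "cuda" if torch.cuda.is_available() else "cpu"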
The original classifier layer has out_features=1000 and I would like out_features=3.
This is the output of model.classifier:
Sequential(
(0): Dropout(p=0.2, inplace=True)
(1): Linear(in_features=1280, out_features=1000, bias=True)
)
Freezing:
# freeze the feature extractor so only the new classifier head gets trained
for param in model.features.parameters():
    param.requires_grad = False
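To confirm the freezing worked, I can count trainable parameters with a quick check like this (just a sanity-check sketch, not part of the training code):

# how many parameters will actually receive gradients after freezing
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable:,} / {total:,}")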
Now, to update the out_features: if I update it directly, like so:
model.classifier[1].out_features = 3
compared with:
model.classifier = torch.nn.Sequential(
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(in_features=1280, out_features=3, bias=True).to(device))
Even though they both appear to yield the same result when I print model.classifier:
Sequential(
(0): Dropout(p=0.2, inplace=True)
(1): Linear(in_features=1280, out_features=3, bias=True)
)
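(In case it helps with diagnosing this, here is the kind of check I could run after each approach to see whether the layer's actual parameters differ and not just its printed repr; only a sketch, output omitted:)

# inspect the layer after each approach (sketch; output omitted)
print(model.classifier[1].out_features)   # the attribute shown in the printed repr
print(model.classifier[1].weight.shape)   # shape of the actual weight tensor
print(model.classifier[1].bias.shape)     # shape of the actual bias tensor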
The former performs poorly (training for 5 epochs):
Epoch: 1 | train_loss: 6.4254 | train_acc: 0.0234 | test_loss: 6.0019 | test_acc: 0.0000
Epoch: 2 | train_loss: 4.0430 | train_acc: 0.2852 | test_loss: 3.5201 | test_acc: 0.1752
Epoch: 3 | train_loss: 2.6871 | train_acc: 0.2539 | test_loss: 2.4525 | test_acc: 0.2585
Epoch: 4 | train_loss: 1.9191 | train_acc: 0.3359 | test_loss: 2.1126 | test_acc: 0.3002
Epoch: 5 | train_loss: 1.6082 | train_acc: 0.3906 | test_loss: 1.9565 | test_acc: 0.3523
Compared with the latter:
Epoch: 1 | train_loss: 1.0294 | train_acc: 0.4453 | test_loss: 0.8789 | test_acc: 0.6420
Epoch: 2 | train_loss: 0.9224 | train_acc: 0.6445 | test_loss: 0.6745 | test_acc: 0.9062
Epoch: 3 | train_loss: 0.7471 | train_acc: 0.8516 | test_loss: 0.6538 | test_acc: 0.8655
Epoch: 4 | train_loss: 0.6976 | train_acc: 0.7500 | test_loss: 0.6515 | test_acc: 0.8456
Epoch: 5 | train_loss: 0.6853 | train_acc: 0.7734 | test_loss: 0.6107 | test_acc: 0.8665
Can anyone explain why this is?
Many thanks