@apaszke I am referencing this PR for fine-tuning. For alexnet and vggnet, the original code replaces all the fully-connected layers. May I ask:
- How can I replace only the last fully-connected layer for fine-tuning and freeze the other fully-connected layers? (I've put a rough sketch of what I mean at the end of this comment.)
- Is the `forward` below the right way to write it? You gave this reference code above:
```python
def forward(self, x):
    return self.last_layer(self.pretrained_model(x))
```
Original fine-tuning code:
```python
import torch.nn as nn


class FineTuneModel(nn.Module):
    def __init__(self, original_model, arch, num_classes):
        super(FineTuneModel, self).__init__()

        if arch.startswith('alexnet'):
            self.features = original_model.features
            self.classifier = nn.Sequential(
                nn.Dropout(),
                nn.Linear(256 * 6 * 6, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, num_classes),
            )
            self.modelName = 'alexnet'
        elif arch.startswith('resnet'):
            # Everything except the last linear layer
            self.features = nn.Sequential(*list(original_model.children())[:-1])
            self.classifier = nn.Sequential(
                nn.Linear(512, num_classes)
            )
            self.modelName = 'resnet'
        elif arch.startswith('vgg16'):
            self.features = original_model.features
            self.classifier = nn.Sequential(
                nn.Dropout(),
                nn.Linear(25088, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, num_classes),
            )
            self.modelName = 'vgg16'
        else:
            raise ValueError("Finetuning not supported on this architecture yet")

        # Freeze the pretrained feature-extractor weights; only the new
        # classifier parameters stay trainable (and only those should be
        # passed to the optimizer).
        for p in self.features.parameters():
            p.requires_grad = False

    def forward(self, x):
        f = self.features(x)
        if self.modelName == 'alexnet':
            f = f.view(f.size(0), 256 * 6 * 6)
        elif self.modelName == 'vgg16':
            f = f.view(f.size(0), -1)
        elif self.modelName == 'resnet':
            f = f.view(f.size(0), -1)
        y = self.classifier(f)
        return y
```
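For the first question, here is a rough sketch of what I have in mind (it assumes torchvision's pretrained alexnet; `num_classes` and the optimizer settings are just placeholders), so please correct me if this is the wrong way to do it:

```python
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

num_classes = 10  # placeholder

model = models.alexnet(pretrained=True)

# Freeze every pretrained parameter, including the earlier FC layers.
for p in model.parameters():
    p.requires_grad = False

# Rebuild the classifier: keep all pretrained layers except the last
# Linear and append a fresh output layer (trainable by default).
layers = list(model.classifier.children())[:-1]
layers.append(nn.Linear(4096, num_classes))
model.classifier = nn.Sequential(*layers)

# Pass only the trainable parameters (the new layer) to the optimizer.
optimizer = optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-3, momentum=0.9)
```

Would this be roughly equivalent to your `FineTuneModel` approach, just keeping the pretrained FC weights instead of reinitializing the whole classifier?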