Does model.train() conflict with a pretrained model?

model_resnet = models.resnet18(pretrained=True)

# inside the __init__ of an nn.Module that receives the pretrained network as `model`
if model is not None:
    self.resnet_layer = nn.Sequential(*list(model.children())[:-1])
    self.num_ftrs = model.fc.in_features
    self.fc = nn.Linear(self.num_ftrs, 10)

Above is part of my code. I want to use a pretrained ResNet to train a new model by replacing its last fully connected layer. During training, I don't want any ResNet weights to change except those of the last fc layer. If I call model.train() before the training process, will all the weights in the ResNet have requires_grad set to True?

No. model.train() and model.eval() only change the behavior of certain layers, e.g. nn.BatchNorm2d and nn.Dropout; they do not touch the requires_grad flag of any parameter. If you would like to freeze all layers but the last one, you have to set their requires_grad flag to False yourself:
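To see this directly, here is a minimal sketch that uses only torch. A small nn.Sequential with BatchNorm2d and Dropout stands in for the pretrained ResNet (an assumption made so nothing has to be downloaded). It freezes every parameter, then toggles train()/eval() and records each module's `training` flag and each parameter's `requires_grad` flag:

```python
import torch.nn as nn

# Hypothetical stand-in for a pretrained backbone; no weights are downloaded.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Dropout(p=0.5),
)

# Freeze every parameter.
for param in model.parameters():
    param.requires_grad_(False)

model.train()
train_flags = [m.training for m in model.modules()]       # mode of each module
train_grads = [p.requires_grad for p in model.parameters()]

model.eval()
eval_flags = [m.training for m in model.modules()]
eval_grads = [p.requires_grad for p in model.parameters()]

print(all(train_flags), any(eval_flags))  # True False
print(any(train_grads), any(eval_grads))  # False False
```

Switching modes flips `training` on every submodule, while the frozen parameters stay frozen in both modes.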

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

model = models.resnet18(pretrained=True)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, 10)  # new, trainable classifier head

# Freeze everything except the new fc layer
for name, param in model.named_parameters():
    if 'fc' not in name:
        print('Freezing {}'.format(name))
        param.requires_grad = False
    else:
        print('Skipping {}'.format(name))

criterion = nn.CrossEntropyLoss()
# Pass only the trainable parameters to the optimizer
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

# Check for gradients
output = model(torch.randn(2, 3, 224, 224))
loss = criterion(output, torch.empty(2, dtype=torch.long).random_(10))
loss.backward()
print(model.fc.weight.grad)  # Should contain gradients
print(model.layer4[0].conv1.weight.grad)  # Should be None (frozen)
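One related caveat, offered as a sketch: requires_grad = False keeps the optimizer from changing the frozen weights, but model.train() still puts the frozen BatchNorm layers into training mode, so their running_mean/running_var buffers keep updating from the new data. The example below uses only torch; `run_batch` and `freeze_batchnorm_stats` are hypothetical helper names, and switching BatchNorm modules back to eval mode after model.train() is one common workaround, not the only option:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def run_batch(bn):
    """Feed one batch through `bn` and report whether running_mean changed."""
    before = bn.running_mean.clone()
    with torch.no_grad():  # no autograd involved at all
        bn(torch.randn(8, 4, 5, 5) + 3.0)
    return not torch.equal(before, bn.running_mean)

bn = nn.BatchNorm2d(4)
bn.requires_grad_(False)  # freeze the affine weight and bias

bn.train()
changed_in_train = run_batch(bn)  # running stats are buffers, updated in train mode

bn.eval()
changed_in_eval = run_batch(bn)   # eval mode uses, but does not update, running stats

print(changed_in_train, changed_in_eval)  # True False


def freeze_batchnorm_stats(model):
    """Hypothetical helper: call after model.train() to keep BatchNorm layers in eval mode."""
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()

net = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
net.train()
freeze_batchnorm_stats(net)
print(net.training, net[1].training)  # True False
```

Whether to keep the pretrained statistics fixed or let them adapt depends on how different the new data is; the sketch only shows that the choice is controlled by the module mode, not by requires_grad.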

Thanks ptrblck, my confusion has been resolved.