I have a pretrained model that I would like to use on a test set. I do the following to check whether the model actually loaded the weights:
checkpoint = torch.load(checkpoint_path, map_location=device)
model = Model()
print(model.conv1.weight[0]) # (1) see the weights of the random initialisation
model.load_state_dict(checkpoint['model_state'], strict=False)
print(checkpoint['model_state']['module.conv1.weight'][0]) # (2) see the weights of the pretrained model
print(model.conv1.weight[0]) # I would expect this to be the same as (2) but it is the same as (1)
model.eval()
I may be wrong, but I think you are trying to load a state dict saved from a DataParallel model into a plain (non-DataParallel) model. Can you please give more information on how you are saving the model, and some idea of its architecture? Also, can you try loading without strict=False?
The strict=False is needed because I load weights from a ResNet with one fully-connected layer into a ResNet that ends in a different fully-connected layer.
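One thing worth noting: with strict=False, load_state_dict silently skips any checkpoint key that the model does not recognise, and since every key in a DataParallel checkpoint carries a "module." prefix, none of them match a plain model's keys, so nothing gets loaded. In PyTorch, load_state_dict actually returns a named tuple with missing_keys and unexpected_keys that you can print to catch this. A minimal, torch-free sketch of that key matching (diff_keys is a hypothetical helper, not a PyTorch function):

```python
def diff_keys(model_keys, ckpt_keys):
    """Mimic load_state_dict's key matching: keys the model expects but the
    checkpoint lacks are 'missing'; checkpoint keys the model does not know
    are 'unexpected' (silently ignored under strict=False)."""
    ckpt = set(ckpt_keys)
    mdl = set(model_keys)
    missing = [k for k in model_keys if k not in ckpt]
    unexpected = [k for k in ckpt_keys if k not in mdl]
    return missing, unexpected

# Every key mismatches because of the DataParallel "module." prefix:
missing, unexpected = diff_keys(["conv1.weight", "fc.weight"],
                                ["module.conv1.weight", "module.fc.weight"])
print(missing)     # ['conv1.weight', 'fc.weight'] -> nothing was loaded
print(unexpected)  # ['module.conv1.weight', 'module.fc.weight']
```

Printing the return value of model.load_state_dict(checkpoint['model_state'], strict=False) would have shown the same mismatch directly.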
So I made the following changes and now it seems to work:
checkpoint = torch.load(checkpoint_path, map_location=device)
model = Model()
model = nn.DataParallel(model, device_ids=device_ids)
model.to(device)
print(model.module.conv1.weight[0]) # shows random weights
model.load_state_dict(checkpoint['model_state'], strict=False)
print(checkpoint['model_state']['module.conv1.weight'][0]) # shows pretrained weights of checkpoint
print(model.module.conv1.weight[0]) # shows pretrained weights in model
The problem was that I had to wrap my model in a DataParallel object before loading the checkpoint, and additionally go through .module when accessing the weights of a specific layer.
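An alternative that avoids wrapping the model in DataParallel just for inference is to strip the "module." prefix from the checkpoint keys before loading. A sketch of that idea (strip_module_prefix is a hypothetical helper; plain numbers stand in for tensors so the example runs without a checkpoint):

```python
def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that nn.DataParallel adds to every
    parameter name, so the checkpoint loads into a plain model."""
    prefix = "module."
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}

fake_state = {"module.conv1.weight": 1.0, "module.fc.weight": 2.0}
print(strip_module_prefix(fake_state))
# {'conv1.weight': 1.0, 'fc.weight': 2.0}
```

With a real checkpoint this would be model.load_state_dict(strip_module_prefix(checkpoint['model_state']), strict=False), and the model's own conv1.weight can then be inspected without .module.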