Check weights after loading state_dict

Hi everyone :slight_smile:

I have a pretrained model that I would like to use on a test set. I do the following to check whether the model actually loaded the weights:

checkpoint = torch.load(checkpoint_path, map_location=device) 
model = Model()
print(model.conv1.weight[0]) # (1) see the weights of the random initialisation

model.load_state_dict(checkpoint['model_state'], strict=False)
print(checkpoint['model_state']['module.conv1.weight'][0]) # (2) see the weights of the pretrained model
print(model.conv1.weight[0]) # I would expect this to be the same as (2) but it is the same as (1)


What am I doing wrong?

Any help is very much appreciated!

All the best,

I may be wrong, but I think you are trying to load a state_dict saved from a DataParallel model into a plain model. Can you please give more information on how you are saving the model, and some details about the model itself? Also, can you try without strict=False?
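One way to confirm this kind of mismatch (a quick diagnostic sketch, not from the original post; the tiny `Model` class here is a stand-in for the poster's ResNet): `load_state_dict` returns a named tuple of `missing_keys` and `unexpected_keys`, so even with strict=False you can print them and see the "module." prefix problem directly:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model with a single conv layer
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)

model = Model()

# Simulate a checkpoint saved from nn.DataParallel:
# every key carries a "module." prefix
wrapped = nn.DataParallel(Model())
state = wrapped.state_dict()

# With strict=False the mismatch is silently ignored, but the
# return value tells you exactly which keys were skipped
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # model parameters that got no weights
print(result.unexpected_keys)  # checkpoint keys that matched nothing
```

If `unexpected_keys` is full of "module."-prefixed names and `missing_keys` lists the same names without the prefix, nothing was actually loaded.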

You are right, the checkpoint file is from a DataParallel model. Does that mean I have to wrap my model in DataParallel again and then load the weights?

I save like this:

checkpoint = {
                'run': self.run_count,
                'epoch': self.epoch_count,
                'model_state': self.model.state_dict(),
                'optimizer_state': self.optimizer.state_dict(),
                'accuracy_val': self.accuracy_val
            }
torch.save(checkpoint, 'filename.pth')

The strict=False is needed since I load weights from a ResNet into a ResNet whose final fully-connected layer has a different shape.

So I made the following changes and now it seems to work:

checkpoint = torch.load(checkpoint_path, map_location=device) 
model = Model()
model = nn.DataParallel(model, device_ids=device_ids) 
print(model.module.conv1.weight[0]) # shows random weights

model.load_state_dict(checkpoint['model_state'], strict=False)
print(checkpoint['model_state']['module.conv1.weight'][0]) # shows pretrained weights of checkpoint

print(model.module.conv1.weight[0]) # shows pretrained weights in model

The problem was that I had to wrap my model in DataParallel so that its parameter names match the "module."-prefixed keys in the checkpoint, and additionally access the weights of a specific layer through .module.
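For reference, an alternative to wrapping the model in DataParallel (a sketch of a common workaround, not from the original thread; the tiny `Model` here is a stand-in) is to strip the "module." prefix from the checkpoint keys so they load into a plain model:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)

# Simulate a DataParallel checkpoint (keys prefixed with "module.")
dp_state = nn.DataParallel(Model()).state_dict()

# Remove the leading "module." from every key
plain_state = {k.removeprefix("module."): v for k, v in dp_state.items()}

model = Model()
model.load_state_dict(plain_state)  # strict load now succeeds
print(list(plain_state))
```

This avoids needing GPUs or a DataParallel wrapper just to read a checkpoint, and you can then access layers directly as `model.conv1` without `.module`. (`str.removeprefix` needs Python 3.9+; on older versions use `k[len("module."):]` guarded by `k.startswith("module.")`.)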