Error(s) in loading state_dict for ResNet

Hi,
Could someone help me with this issue? I am unable to get it right.

I created the model like this:

```python
model = torchvision.models.resnet50(pretrained=True)
```

Inside the for loop for each epoch, I saved the state_dict like this:

```python
state = {'epoch': epoch + 1,
         'state_dict': model.state_dict(),
         'optimizer': optimizer.state_dict()}
torch.save(state, FILE_NAME)
```

After a crash, I decided to load the model, but I am unable to get a handle on how to load it.
I did this:
```python
model = torchvision.models.resnet50(pretrained=True)
checkpoint = torch.load(filename)
start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
```

The error I am getting is below:

```
RuntimeError: Error(s) in loading state_dict for ResNet:
	size mismatch for fc.weight: copying a param with shape torch.Size([18425, 2048]) from checkpoint, the shape in current model is torch.Size([1000, 2048]).
	size mismatch for fc.bias: copying a param with shape torch.Size([18425]) from checkpoint, the shape in current model is torch.Size([1000]).
```

Did you modify the resnet50 from torchvision?

According to your error, the last fc layer in your checkpoint has an 18425-dimensional output, while the output dim of the original resnet50 is 1000. I think you should check the shape of the fc layer's weight and find where the mismatch occurs.
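
For example, a quick check along those lines might look like this (a minimal sketch; the checkpoint path is an assumption):

```python
import torch
import torchvision

FILE_NAME = 'checkpoint.pth'  # hypothetical path used when saving

# Fresh torchvision resnet50 with the default 1000-class head
model = torchvision.models.resnet50(pretrained=True)
checkpoint = torch.load(FILE_NAME, map_location='cpu')

print(model.fc.weight.shape)                         # torch.Size([1000, 2048])
print(checkpoint['state_dict']['fc.weight'].shape)   # torch.Size([18425, 2048]) per the error
```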

P.S. please format your code with two ```

I didn't modify ResNet. Thanks, I will check the shape.

First of all, welcome to the PyTorch Community!!
Coming to your question,

I am pretty sure that when you were training it for the first time you somehow tweaked the resnet50 that comes with torchvision.

From what it seems, you have changed the final layer (most probably for transfer learning purposes) with something like:

```python
model.fc = nn.Linear(2048, 18425)
```

You probably did not intend this exact value; I think 18425 was calculated somewhere in your code and a bug produced that number.

For now, if you need to load this checkpoint anyway, just load whichever parts are okay: manually take the fc.weight tensor out of the checkpoint and slice it,

```python
fc.weight[:nc, :]   # rows correspond to output classes, so slice the first dimension
```

where nc is the number of classes you want (and slice fc.bias[:nc] the same way). Store all these weights in a new_dict and then


```python
sd = model.state_dict()
sd.update(new_dict)
model.load_state_dict(sd)
```

Make sure every key in new_dict is also a valid key in sd.
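
Putting it together, a minimal sketch of that partial-load idea might look like this (the checkpoint path and nc value are assumptions):

```python
import torch
import torchvision

nc = 1000  # hypothetical number of classes for the new head

model = torchvision.models.resnet50(pretrained=True)
checkpoint = torch.load('checkpoint.pth', map_location='cpu')  # assumed path
ckpt_sd = checkpoint['state_dict']

# Slice the oversized classifier parameters down to nc classes
new_dict = {
    'fc.weight': ckpt_sd['fc.weight'][:nc, :],
    'fc.bias': ckpt_sd['fc.bias'][:nc],
}
# Every other tensor (the backbone) already has a matching shape, so copy it as-is
for key, value in ckpt_sd.items():
    if key not in new_dict:
        new_dict[key] = value

sd = model.state_dict()
sd.update(new_dict)
model.load_state_dict(sd)
```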


I will try the above, thanks.
It looks like the same error was solved using this method:

```python
model = resnet18(pretrained=True, num_classes=4)  # where num_classes will be different
model.load_state_dict(checkpoint)  # load
```

In general, if you have a different number of output classes than the original network, the pretrained model is not going to work as you expect. Either you have fewer classes, but they don't mean the same thing as the original ones, or you have more classes, and the additional weights are simply going to be newly initialized (whether you specify an initialization or use the default one).

This is why fine-tuning is often required. See this official tutorial for detailed explanations.
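
If the goal is just to resume training from that checkpoint, one option (a sketch, assuming the 18425-class head was intentional and that the optimizer settings match the original run) is to rebuild the model with the same head before loading:

```python
import torch
import torch.nn as nn
import torchvision

num_classes = 18425  # must match the head that was saved in the checkpoint

# Recreate the architecture exactly as it was when the checkpoint was written
model = torchvision.models.resnet50(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, num_classes)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # hypothetical optimizer and lr

checkpoint = torch.load('checkpoint.pth', map_location='cpu')  # assumed path
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
```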

Thanks Alex, yes, it doesn't work. So this time around I saved both the model state_dict and the full model.

As I started training again, I had to stop in the middle because of a power cut and just wanted to resume training.
This time around, load_state_dict still had problems with the size, but by loading the full model I was able to resume training with the current loss.

Here it is. When I started, the GAP score was 0.0000:

```
epoch 1
total batches: 5024
1 [0/5024]      time 237.078 (237.078)  loss 9.9302 (9.9302)    GAP 0.0000 (0.0000)
```

After resuming by loading the full model, the GAP score is below:

```
epoch 1
total batches: 5024
1 [0/5024]      time 68.160 (68.160)    loss 5.0118 (5.0118)    GAP 0.0345 (0.0345)
```

So, the solution to loading the state_dict when using a pretrained model is to avoid copying fc.weight and fc.bias:

```python
avoid = ['fc.weight', 'fc.bias']
for key in pytorch_state_dict.keys():
    # skip the classifier parameters and any keys the target state dict does not have
    if key in avoid or key not in state_dict.keys():
        # print('not in', key)
        continue
    # skip any tensor whose shape does not match
    if pytorch_state_dict[key].size() != state_dict[key].size():
        # print('size not the same', key)
        continue
    state_dict[key] = pytorch_state_dict[key]
```
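
For completeness, a minimal end-to-end sketch of how that filtered copy might be used (assuming pytorch_state_dict is the pretrained torchvision state dict and state_dict belongs to the custom 18425-class model; the model construction here is an assumption):

```python
import torch.nn as nn
import torchvision

num_classes = 18425  # the custom head size from the error message

# Custom model: resnet50 backbone with a larger classification head
model = torchvision.models.resnet50(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Pretrained torchvision weights (the source) and the custom model's dict (the target)
pytorch_state_dict = torchvision.models.resnet50(pretrained=True).state_dict()
state_dict = model.state_dict()

avoid = ['fc.weight', 'fc.bias']
for key in pytorch_state_dict.keys():
    if key in avoid or key not in state_dict.keys():
        continue
    if pytorch_state_dict[key].size() != state_dict[key].size():
        continue
    state_dict[key] = pytorch_state_dict[key]

model.load_state_dict(state_dict)  # backbone is pretrained, fc stays freshly initialized
```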

Thanks also to Arunava Chakraborty for his advice.
