Error(s) in loading state_dict for ResNet

Could someone help me with this issue? I am unable to get it right.

I created the model like so:

```python
model = torchvision.models.resnet50(pretrained=True)
```

Inside the for loop of each epoch, I saved the state_dict like so:

```python
torch.save({'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()}, FILE_NAME)
```

After a crash, I decided to load the model, but I am unable to get a handle on how to do it.
I did this:

```python
model = torchvision.models.resnet50(pretrained=True)
checkpoint = torch.load(filename)
start_epoch = checkpoint['epoch']
```

The error I am getting is below:

```
RuntimeError: Error(s) in loading state_dict for ResNet:
	size mismatch for fc.weight: copying a param with shape torch.Size([18425, 2048]) from checkpoint, the shape in current model is torch.Size([1000, 2048]).
	size mismatch for fc.bias: copying a param with shape torch.Size([18425]) from checkpoint, the shape in current model is torch.Size([1000]).
```
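For reference, a minimal save-and-resume round trip looks like the sketch below. It uses a tiny stand-in module instead of resnet50 so it is self-contained, and `FILE_NAME` is a hypothetical path; the key point is that you must rebuild the *same* architecture before calling `load_state_dict`:

```python
import torch
import torch.nn as nn

FILE_NAME = 'checkpoint.pth'  # hypothetical path

# Stand-ins for the real network and optimizer
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# --- inside the epoch loop: save everything needed to resume ---
epoch = 0
torch.save({'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()}, FILE_NAME)

# --- after a crash: rebuild the SAME architecture, then restore ---
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
checkpoint = torch.load(FILE_NAME)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
```

If the rebuilt model's layer shapes differ from what was saved, `load_state_dict` raises exactly the size-mismatch error above.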

Did you modify the resnet50 from torchvision?

According to your error, the output of the last fc layer in your checkpoint is 18425-dimensional, while the output dim of the original resnet50 is 1000. I think you should check the shape of the fc layer's weight and find where the mismatch comes from.

P.S. Please format your code by wrapping it between two sets of ```

I didn’t modify ResNet. Thanks, I will check the shape.

First of all, welcome to the PyTorch Community!!
Coming to your question,

I am pretty sure that when you trained it the first time, you somehow tweaked the resnet50 that comes with torchvision.

From the error, it looks like you replaced the final layer (most probably for transfer-learning purposes) with something like:

```python
model.fc = nn.Linear(2048, 18425)
```

You probably did not intend that exact number; I suspect 18425 was computed somewhere in your code and a bug produced it.

For now, if you need to load this checkpoint anyway, load whichever parts are okay, manually take the `fc.weight` tensor out of the checkpoint, and keep a slice of it:

```python
fc.weight[:nc]
```

where `nc` is the number of classes you want (`fc.weight` has shape `[num_classes, 2048]`, so the class dimension is the first one). Store all these weights in a `new_dict`, then get the model's current state dict:

```python
sd = model.state_dict()
```

and make sure `new_dict` and `sd` have the same keys before loading.
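A sketch of that idea with small stand-in modules (the two-layer `nn.Sequential` models here are hypothetical, chosen just so the final-layer shapes differ like yours do):

```python
import torch.nn as nn

# Checkpointed model and current model differ only in the final layer size
saved = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 18))
current = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 3))

ckpt = saved.state_dict()
sd = current.state_dict()

# Keep only the entries whose key exists in the current model and whose
# shape matches; everything else keeps the current model's weights
new_dict = {k: v for k, v in ckpt.items()
            if k in sd and v.size() == sd[k].size()}

sd.update(new_dict)          # sd still has every key the model expects
current.load_state_dict(sd)  # loads without any size-mismatch error
print(sorted(new_dict))      # ['0.bias', '0.weight']
```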


I will try the above, thanks.
Looks like the same error was solved using this method:

```python
model = resnet18(pretrained=True, num_classes=4)  # where num_classes will be different
model.load_state_dict(checkpoint)  # load
```

In general, if you have a different number of output classes than the original network, the pretrained model is not going to work as you expect. Either you have fewer classes, but they don’t mean the same thing as the original ones, or you have more classes, and the additional weights are simply randomly initialized (with whatever initialization you specified, or the default one).

This is why fine-tuning is often required. See this official tutorial for detailed explanations.

Thanks Alex, yes, it doesn't work. So this time around I saved both the model's state_dict and the full model.

As I started training again, I had to stop in the middle because of a power cut, and I just wanted to resume training.
This time around, load_state_dict still had problems with the sizes, but by loading the full model I was able to resume training with the current loss.

Here it is. When I started, the GAP score was 0.0000:

```
epoch 1
total batches: 5024
1 [0/5024]      time 237.078 (237.078)  loss 9.9302 (9.9302)    GAP 0.0000 (0.0000)
```

After resuming by loading the full model, the GAP score is below:

```
epoch 1
total batches: 5024
1 [0/5024]      time 68.160 (68.160)    loss 5.0118 (5.0118)    GAP 0.0345 (0.0345)
```
So, the solution to loading a state_dict when using a pretrained model is to avoid copying fc.weight and fc.bias:

```python
avoid = ['fc.weight', 'fc.bias']
for key in pytorch_state_dict.keys():
    if key in avoid or key not in state_dict.keys():
        # print('not in', key)
        continue
    if pytorch_state_dict[key].size() != state_dict[key].size():
        # print('size not the same', key)
        continue
    state_dict[key] = pytorch_state_dict[key]
```
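An alternative sketch of the same idea: drop the mismatched classifier keys from the checkpoint and load with `strict=False`, which tolerates missing keys (it does *not* tolerate shape mismatches, which is why they must be removed first). The small `Net` module here is a hypothetical stand-in for ResNet:

```python
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.body = nn.Linear(8, 4)           # stand-in for the backbone
        self.fc = nn.Linear(4, num_classes)   # classifier head

ckpt = Net(18).state_dict()  # checkpoint trained with 18 classes
model = Net(3)               # current model wants 3 classes

# Remove the keys whose shapes changed, then load non-strictly;
# the backbone weights are restored, the head keeps its fresh init
for key in ('fc.weight', 'fc.bias'):
    ckpt.pop(key, None)
missing, unexpected = model.load_state_dict(ckpt, strict=False)
print(missing)  # the fc keys are reported as missing, not an error
```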

Thanks also to Arunava Chakraborty for his advice.
