I'm facing the same issue too, and I don't use nn.DataParallel.
I am using a slightly modified version of [this repo](https://github.com/aitorzip/PyTorch-CycleGAN) in a Kaggle notebook.
Here’s how I save:
```python
torch.save({
    'epoch': epoch,
    'model_state_dict': netG_A2B.state_dict(),
    'optimizer_state_dict': optimizer_G.state_dict(),
    'loss_histories': loss_histories,
}, f'netG_A2B_{epoch:03d}.pth')
```
Here’s how I load:
```python
checkpoint = torch.load(weights_path)['model_state_dict']
self.model.load_state_dict(checkpoint)
```
And here's a representative snippet of the (longer) error message:
```
Missing key(s) in state_dict: "1.weight", "1.bias", "4.weight", "4.bias", "7.weight", "7.bias".
Unexpected key(s) in state_dict: "model.1.weight", "model.1.bias", "model.4.weight", "model.4.bias", "model.7.weight", "model.7.bias".
```
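My guess at the cause: the network I saved keeps its layers in an attribute named `model` (a wrapper module holding an `nn.Sequential`), so every key gets a `model.` prefix, while the module I load into is the bare `Sequential`. A minimal illustration, using a hypothetical toy `Wrapper` class rather than the real generator:

```python
import torch.nn as nn

# Hypothetical wrapper that stores its layers under the attribute name "model"
class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        # Index 0 (padding) has no parameters, index 1 (conv) does
        self.model = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(3, 8, 7))

    def forward(self, x):
        return self.model(x)

wrapped = Wrapper()
# Keys from the wrapper carry the attribute name as a prefix
print(list(wrapped.state_dict().keys()))        # ['model.1.weight', 'model.1.bias']
# Keys from the inner Sequential do not
print(list(wrapped.model.state_dict().keys()))  # ['1.weight', '1.bias']
```

So a checkpoint saved from the wrapper only loads directly into another wrapper; loading it into the bare `Sequential` requires removing the prefix first.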
I fixed it by stripping the `model.` prefix from the keys before loading:
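```python
checkpoint = torch.load(weights_path, map_location=self.device)['model_state_dict']
# Iterate over a copy of the keys because the dict is modified inside the loop
for key in list(checkpoint.keys()):
    # Only strip a leading "model.", not one that appears mid-key
    if key.startswith('model.'):
        checkpoint[key[len('model.'):]] = checkpoint.pop(key)
self.model.load_state_dict(checkpoint)
```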
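If you need this in more than one place, the renaming can be pulled into a small helper. This is a sketch: `strip_prefix` is a hypothetical name, it builds a new dict instead of mutating the checkpoint in place, and it refuses to silently overwrite a key if stripping causes a collision. Matching with `startswith` matters, because a substring check plus `str.replace` would also turn a key like `submodel.x` into `subx`. (Recent PyTorch versions also provide `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`, which strips a prefix in place.)

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")

def strip_prefix(state_dict: Mapping[str, V], prefix: str) -> "OrderedDict[str, V]":
    """Return a copy of state_dict with `prefix` removed from the start of each key.

    Keys that don't start with `prefix` are kept unchanged, and key order is
    preserved. Raises ValueError if two keys would map to the same name.
    """
    out: "OrderedDict[str, V]" = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        if new_key in out:
            raise ValueError(f"duplicate key after stripping prefix: {new_key!r}")
        out[new_key] = value
    return out
```

Usage with the loading code above would be `self.model.load_state_dict(strip_prefix(checkpoint, 'model.'))`.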