@ptrblck Model is saved as DDP module and I get this error
if (gpu0 == 0 and epoch == 4):
model.cpu()
save_checkpoint({
'epoch': epoch + 1,
'state_dict': model.state_dict(),
'optimizer' : optimizer.state_dict(),
}, epoch)
class Net(nn.Module):
def __init__(self, gpu0, gpu1):
super(Net, self).__init__()
if gpu0 != "cpu":
self.gpu0 = "cuda:"+str(gpu0)
self.gpu1 = "cuda:"+str(gpu1)
else:
self.gpu0 = "cpu"
self.gpu1 = "cpu"
self.conv = nn.Sequential( nn.Conv2d(1, 32, 3, 1), nn.ReLU(), nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
nn.MaxPool2d(2), nn.Dropout(0.25), nn.Flatten(1),
).to(self.gpu0)
self.feat = nn.Sequential(nn.Linear(9216, 128), nn.BatchNorm1d(128),
nn.ReLU(), nn.Dropout2d(0.5), nn.Linear(128, 10)
).to(self.gpu1)
def forward(self, x):
x = self.conv(x).to(self.gpu1)
x = self.feat(x)
output = F.log_softmax(x, dim=1)
return output
model = Net("cpu", "cpu")
PATH = '../model_4.pth'
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['state_dict'])
Error
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-3-ec88a5c31813> in <module>
2 PATH = '../model_4.pth'
3 checkpoint = torch.load(PATH)
----> 4 model.load_state_dict(checkpoint['state_dict'])
5 #optimizer.load_state_dict(checkpoint['optimizer'])
~/.conda/envs/praveen_tf/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1043 if len(error_msgs) > 0:
1044 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1045 self.__class__.__name__, "\n\t".join(error_msgs)))
1046 return _IncompatibleKeys(missing_keys, unexpected_keys)
1047
RuntimeError: Error(s) in loading state_dict for Net:
Missing key(s) in state_dict: "conv.0.weight", "conv.0.bias", "conv.2.weight", "conv.2.bias", "feat.0.weight", "feat.0.bias", "feat.1.weight", "feat.1.bias", "feat.1.running_mean", "feat.1.running_var", "feat.4.weight", "feat.4.bias".
Unexpected key(s) in state_dict: "module.conv.0.weight", "module.conv.0.bias", "module.conv.2.weight", "module.conv.2.bias", "module.feat.0.weight", "module.feat.0.bias", "module.feat.1.weight", "module.feat.1.bias", "module.feat.1.running_mean", "module.feat.1.running_var", "module.feat.1.num_batches_tracked", "module.feat.4.weight", "module.feat.4.bias".