I saved my network as
checkpoint = {'epoch': epoch, 'loss': loss,
              'model_state_dict': network.module.state_dict(),
              'optimizer_state_dict': optimizer.state_dict()}
torch.save(checkpoint, 'check-point.pth')
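A minimal sketch of what keys the two state_dict calls produce, assuming a toy nn.Linear stands in for the real network (the names `base` and `wrapped` are hypothetical). nn.DataParallel registers the wrapped model as a submodule called `module`, so only the wrapper's own state_dict gets the "module." prefix:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real network: a single linear layer.
base = nn.Linear(4, 2)
# DataParallel also constructs without GPUs; it then just calls the inner module.
wrapped = nn.DataParallel(base)

# The wrapper's state_dict keys carry a "module." prefix.
wrapped_keys = list(wrapped.state_dict().keys())
# network.module.state_dict() is the inner model's own state_dict: no prefix.
inner_keys = list(wrapped.module.state_dict().keys())

print(wrapped_keys)  # ['module.weight', 'module.bias']
print(inner_keys)    # ['weight', 'bias']
```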
and was able to restore it using
with open('check-point.pth', 'rb') as f:
    checkpoint = torch.load(f)
network.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
although I had to wrap it with nn.DataParallel(network) afterwards. Shouldn't this throw an error, though? Since the weights were saved with network.module.state_dict(), I expected that network.module.load_state_dict() would be needed to restore them.
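A minimal sketch of the loading order described above, assuming a toy nn.Linear in place of the real network (`make_model`, `saved`, and `wrapped` are hypothetical names). The saved keys have no "module." prefix, so they match a plain, not-yet-wrapped model exactly; only loading into the wrapper itself would fail:

```python
import torch
import torch.nn as nn

# Hypothetical toy model standing in for `network`.
def make_model():
    return nn.Linear(4, 2)

# Saving as in the question: the inner model's state_dict (keys without "module.").
saved = nn.DataParallel(make_model()).module.state_dict()

# Loading order from the question: plain model first, wrap afterwards.
network = make_model()
network.load_state_dict(saved)      # keys match exactly, so no error
network = nn.DataParallel(network)  # wrapping keeps the loaded weights

# Loading the same dict into an already-wrapped model expects "module." keys.
wrapped = nn.DataParallel(make_model())
try:
    wrapped.load_state_dict(saved)  # strict=True by default
    wrapped_load_failed = False
except RuntimeError:
    wrapped_load_failed = True      # missing "module.*" and unexpected keys

# Loading through .module works, because it targets the inner model directly.
wrapped.module.load_state_dict(saved)
```

In other words, network.module.load_state_dict() would only be required if the model were wrapped before loading.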