If anyone has the same issue, the two things I did to fix the problem were:
- Restore the original code, `model.load_state_dict(checkpoint['state_dict'])`, as suggested by @ptrblck above.
- In the epoch loop, just before calling `save_checkpoint`, I added:
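```python
try:
    # Model is wrapped (e.g. nn.DataParallel): save the inner model's weights
    model_state_dict = model.module.state_dict()
except AttributeError:
    # Plain, unwrapped model: it has no .module attribute
    model_state_dict = model.state_dict()
```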
This is based on a suggestion by @alex.veuthey here.
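If you already have checkpoints that were saved from the wrapped model, the same mismatch can also be handled on the loading side. The sketch below assumes the model was wrapped in `nn.DataParallel` or `DistributedDataParallel` (which is what `model.module` implies). Both register the original model as a submodule named `module`, so every key in the saved `state_dict` gets a `module.` prefix. `strip_module_prefix` is a hypothetical helper name, not a PyTorch API:

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")


def strip_module_prefix(state_dict: Mapping[str, V], prefix: str = "module.") -> "OrderedDict[str, V]":
    """Return a copy of state_dict with one leading wrapper prefix removed from each key.

    Keys that don't start with the prefix are kept unchanged, so this is safe
    to call on checkpoints saved from either a wrapped or an unwrapped model.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Hypothetical usage with a checkpoint laid out as in the post:
#   checkpoint = torch.load(path)  # `path` is a placeholder
#   model.load_state_dict(strip_module_prefix(checkpoint['state_dict']))

example = {"module.fc.weight": 1, "module.fc.bias": 2}
print(dict(strip_module_prefix(example)))  # {'fc.weight': 1, 'fc.bias': 2}
```

Unwrapping before saving, as in the fix above, is still the cleaner option. With that approach, new checkpoints load into a plain model without any key rewriting.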