I’m doing some research on local minima, and I’m training a simple auto-encoder with only 1 weight matrix (W) and 1 bias (b) on a toy dataset. I’ve run into some very peculiar behavior. At some epoch the model sits almost perfectly at a local optimum, and I save that state into a checkpoint. When I reload the checkpoint and continue training, like so:
import torch

model = ToyModel(input_dim=7, hidden_dim=2, bias=True)

# Load the checkpoint saved near the local optimum
chkpt = torch.load(pth_file)
model.load_state_dict(chkpt['model_state_dict'])
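(For reference, the ToyModel used above is a single-layer auto-encoder along these lines; this is only a sketch, and the exact forward pass shouldn't matter for the question.)

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # Roughly: one weight matrix W and one bias b, decoding with W transposed
    def __init__(self, input_dim=7, hidden_dim=2, bias=True):
        super().__init__()
        self.W = nn.Parameter(0.1 * torch.randn(input_dim, hidden_dim))
        self.b = nn.Parameter(torch.zeros(hidden_dim)) if bias else None

    def forward(self, x):
        h = x @ self.W                    # encode
        if self.b is not None:
            h = h + self.b
        return h @ self.W.t()             # decode with the tied (transposed) weights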
The model continues training just fine, in that I can see the weights updating. However, when I add a simple load-and-save round trip, like so:
model = ToyModel(input_dim=7, hidden_dim=2, bias=True)
# Load 1
chkpt = torch.load(pth_file)
model.load_state_dict(chkpt['model_state_dict'])
# Save as is
torch.save({
    'epoch': -1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': SOME_SGD_OPTIMIZER,
    'scheduler_state_dict': SOME_CONSTANT_LR_SCHEDULER,
    'train_loss': 0,
    'test_loss': 0,
}, save_loc)
# Load 2
chkpt = torch.load(save_loc)
model.load_state_dict(chkpt['model_state_dict'])
This time the model does not continue training: the weights no longer update. This implies that the weights in the original chkpt are somehow different from the weights in save_loc. What’s going on here?
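To test that implication directly, I plan to diff the two saved state dicts tensor by tensor, something along these lines (a sketch, reusing pth_file and save_loc from the snippets above):

a = torch.load(pth_file)['model_state_dict']
b = torch.load(save_loc)['model_state_dict']

for name in a:
    ta, tb = a[name], b[name]
    identical = torch.equal(ta, tb)                            # same shape, dtype, and values
    max_diff = (ta.double() - tb.double()).abs().max().item()  # largest elementwise difference
    print(f"{name}: {ta.dtype} -> {tb.dtype}, identical={identical}, max |diff|={max_diff:.3e}")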
- I don’t think it has to do with whether gradients are required (i.e. it’s not a problem with calling .detach(), or with requires_grad being turned on or off, etc.), because when I modify model.W between the # Load 1 and # Save as is steps, training continues fine afterwards, so the code isn’t silently turning off gradient computation. (See the first check sketched after this list.)
- I suspect it has to do with some loss of precision when calling model.state_dict() or torch.save() (or maybe even model.load_state_dict()), but I can’t find any leads on what exactly this would be. (See the second check sketched after this list.)
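To make both checks concrete, here is roughly what I run right after # Load 2 (a sketch: the dummy batch and plain MSE reconstruction loss stand in for my real training data and loss):

# Check 1: gradients are enabled and actually flow after Load 2
for name, p in model.named_parameters():
    print(name, "requires_grad =", p.requires_grad)

x = torch.randn(16, 7)                       # dummy batch with the same input_dim
loss = ((model(x) - x) ** 2).mean()          # stand-in reconstruction (MSE) loss
loss.backward()
for name, p in model.named_parameters():
    print(name, "max |grad| =", p.grad.abs().max().item())

# Check 2: dtypes survive the save/load round trip unchanged
for name, t in torch.load(save_loc)['model_state_dict'].items():
    print(name, t.dtype)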
Help much appreciated!!