Loading and saving a checkpoint changes it?

I’m doing some research on local minima, and I’m training a simple auto-encoder with only one weight matrix (W) and one bias (b) on a toy dataset. I’ve run into some very peculiar behavior. At some epoch the encoder is almost perfectly at a local optimum, so I save it into a checkpoint. When I reload the checkpoint and continue training it, like so:

model = ToyModel(input_dim=7, hidden_dim=2, bias=True)
chkpt = torch.load(pth_file)
model.load_state_dict(chkpt['model_state_dict'])
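
(For reference, ToyModel isn’t shown above; it’s roughly a tied-weight auto-encoder along these lines. This is a simplified sketch, not my exact code.)

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # Tied-weight auto-encoder: a single weight matrix W and a single bias b.
    def __init__(self, input_dim, hidden_dim, bias=True):
        super().__init__()
        self.W = nn.Parameter(torch.randn(hidden_dim, input_dim) * 0.1)
        self.b = nn.Parameter(torch.zeros(hidden_dim)) if bias else None

    def forward(self, x):
        h = x @ self.W.T               # encode: (batch, input_dim) -> (batch, hidden_dim)
        if self.b is not None:
            h = h + self.b
        return h @ self.W              # decode with the same (tied) weights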

The model continues training just fine, in the sense that I can see the weights updating. However, when I add a simple load-and-save round trip, like so:

model = ToyModel(input_dim=7, hidden_dim=2, bias=True)

# Load 1
chkpt = torch.load(pth_file)
model.load_state_dict(chkpt['model_state_dict'])

# Save as is
torch.save({
    'epoch': -1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': SOME_SGD_OPTIMIZER,          # placeholder: a freshly created SGD optimizer
    'scheduler_state_dict': SOME_CONSTANT_LR_SCHEDULER,  # placeholder: a freshly created constant-LR scheduler
    'train_loss': 0,
    'test_loss': 0,
}, save_loc)

# Load 2
chkpt = torch.load(save_loc)
model.load_state_dict(chkpt['model_state_dict'])

With this, the model does not continue training: the weights no longer update. That suggests the weights in the original chkpt are different from the weights in save_loc. What’s going on here?

  • I don’t think it’s a gradient-tracking issue (i.e. not a stray .detach() or requires_grad being switched off), because if I manually modify model.W between the # Load 1 and # Save as is steps, training continues fine afterwards, so the code isn’t disabling gradient computation.
  • I suspect some precision is lost when calling model.state_dict() or torch.save() (or maybe even model.load_state_dict()), but I can’t find any leads on what it could be (see the check sketched after this list).
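
To rule out the precision hypothesis, one simple check is to compare the tensors in the original and re-saved checkpoints element for element. (torch.save serializes the raw tensor data, so a save/load round trip shouldn’t lose precision, but this makes it explicit.)

import torch

orig = torch.load(pth_file)['model_state_dict']
resaved = torch.load(save_loc)['model_state_dict']

for name, tensor in orig.items():
    identical = torch.equal(tensor, resaved[name])   # exact element-wise equality
    print(f"{name}: dtype={tensor.dtype}, identical={identical}")

If every tensor comes back identical, the weights survived the round trip untouched and the problem lies elsewhere, which turned out to be the case here (see the resolution below).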

Help much appreciated!!

Whoops, I found the problem: my newly created scheduler was incompatible with the optimizer. Somewhere in there the scheduler ended up with an lr of 0.0, and/or there was aggressive weight decay; the code above doesn’t show any of that.

Not a PyTorch problem, case closed.
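
For future readers hitting the same symptom: the pattern that avoids this class of bug is to save the optimizer’s and scheduler’s state_dict()s rather than the objects, rebuild both with the intended hyperparameters before loading, and sanity-check the learning rate afterwards. A minimal sketch, with SGD and ConstantLR standing in for the placeholders above:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ConstantLR

model = ToyModel(input_dim=7, hidden_dim=2, bias=True)
optimizer = SGD(model.parameters(), lr=1e-2)      # the lr you actually intend to train with
scheduler = ConstantLR(optimizer, factor=1.0)

# Save: store state_dicts, not the optimizer/scheduler objects themselves.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
}, save_loc)

# Resume: rebuild everything first, then load the saved states.
chkpt = torch.load(save_loc)
model.load_state_dict(chkpt['model_state_dict'])
optimizer.load_state_dict(chkpt['optimizer_state_dict'])
scheduler.load_state_dict(chkpt['scheduler_state_dict'])

# Sanity check: a silently loaded lr of 0.0 (or aggressive weight_decay)
# is exactly what makes the weights stop moving.
for group in optimizer.param_groups:
    print("lr =", group['lr'], "weight_decay =", group.get('weight_decay', 0))

The final loop is what would have caught my bug immediately: an lr of 0.0 (or an overly large weight_decay) in optimizer.param_groups is enough to make the weights look frozen even though the checkpointed weights themselves are fine.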