I want to resume training, so I try to load the state_dict of an Adam optimizer; however, I find that the loss increases for a few hundred iterations before going down again.
Later I found that self.state in torch.optim.Adam is always empty, which causes the state to be re-initialized on every optimizer.step() call. Should we modify the Adam code by adding self.state[p] = state at the end of the loop?
I’m using ignite for training and its ModelCheckpoint for saving the model’s and optimizer’s state_dict(). I found that self.state of every saved optimizer is an empty dict().
If I load the state_dict after re-initializing the optimizer, optimizer.state is set again:
import torch
import torch.nn as nn

# Create dummy model and optimizer
model = nn.Linear(10, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Create dummy backward pass
out = model(torch.randn(1, 10))
out.mean().backward()
optimizer.step()
# Check state and store state_dict
print(optimizer.state)
state_dict = optimizer.state_dict()
# Re-initialize and check state
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
print(optimizer.state)
# Load and check state
optimizer.load_state_dict(state_dict)
print(optimizer.state)
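For reference, after the first step optimizer.state holds Adam's running statistics for each parameter that has been updated. A minimal sketch of what to expect:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(1, 10)).mean().backward()
optimizer.step()

# Adam tracks the step count and the first/second moment estimates
# (exp_avg, exp_avg_sq) per parameter it has updated
for param, param_state in optimizer.state.items():
    print(sorted(param_state.keys()))  # ['exp_avg', 'exp_avg_sq', 'step']
```

If this state is missing after resuming, Adam restarts its moment estimates from zero, which matches the temporary loss increase you are seeing.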
Could you post the code you are using to store and load the optimizer?
It seems that there is a bug in my training code. I shouldn’t pass obj.state_dict() to the ModelCheckpoint:
checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict()})
Instead, I should pass the obj itself:
checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
'optimizer': optimizer,
'scheduler': scheduler})
model.state_dict() happens to work in both situations, but for the optimizer and scheduler only the initial state is saved. The reason is likely the different implementations of state_dict(): nn.Module.state_dict() returns references to the live parameter tensors, so a dict captured early still reflects later updates, while optim.Optimizer and optim._LRScheduler build a fresh snapshot each time state_dict() is called.
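The difference can be reproduced without ignite. This is a minimal sketch comparing a state_dict snapshot taken at handler-registration time (the buggy pattern) with one taken at checkpoint time:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Snapshot taken when the handler is registered -- state is still empty,
# and it stays empty because state_dict() builds a new dict at call time
early_snapshot = optimizer.state_dict()

model(torch.randn(1, 10)).mean().backward()
optimizer.step()

# Snapshot taken at checkpoint time -- per-parameter state is populated
late_snapshot = optimizer.state_dict()

print(len(early_snapshot['state']))  # 0
print(len(late_snapshot['state']))   # 2 (weight and bias)
```

Passing the object itself lets ModelCheckpoint call state_dict() at save time, which corresponds to the late snapshot.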
Hi ptrblck,
I tried to print opt.state, but it is always empty, like this:
defaultdict(<class 'dict'>, {})
What is state used for?
Hi @ptrblck,
TL;DR - Can you save optim.state to load in after a checkpoint? (Instead of it being re-initialized)
Apologies for opening this post again (after 4 years), but is there a way to save optim.state when saving the state_dict, instead of it being re-initialized after re-initializing the class?
The case I have in mind: if I train a model with a custom optimizer that uses self.state in some way, and the run is interrupted and has to resume from a checkpoint/torch.save file, the behavior won’t be consistent between a model that ran without interruption and one that resumed.
Yes, you can store the optimizer.state_dict() and load it via optimizer.load_state_dict(state_dict) as is also done for the model.
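A minimal save-and-resume sketch (the checkpoint path is just an example):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(1, 10)).mean().backward()
optimizer.step()

# Save both state_dicts together in one checkpoint file
checkpoint = {'model': model.state_dict(),
              'optimizer': optimizer.state_dict()}
torch.save(checkpoint, 'checkpoint.pt')

# Resume: re-create the objects, then restore their state
model = nn.Linear(10, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

print(len(optimizer.state))  # per-parameter Adam state is restored
```

After load_state_dict the optimizer's internal state (including any custom self.state your optimizer maintains, as long as it lives in self.state keyed by parameter) continues from the checkpoint rather than being re-initialized.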