I want to resume training and tried to load the state_dict of an Adam optimizer, but I find that the loss increases for a few hundred iterations before going down again. I later found that self.state in torch.optim.Adam is always empty, which causes the state to be re-initialized on every optimizer.step() call. Should we modify the Adam code by adding self.state[p] = state at the end of the loop?
I’m using Ignite for training and its ModelCheckpoint handler for saving the model’s and optimizer’s state_dict(). I found that all the saved optimizers’ self.state is an empty dict(). If I load the state_dict after re-initializing the optimizer, optimizer.state is set again:
import torch
import torch.nn as nn

# Create dummy model and optimizer
model = nn.Linear(10, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Create dummy backward pass
out = model(torch.randn(1, 10))
out.mean().backward()
optimizer.step()
# Check state and store state_dict
print(optimizer.state)
state_dict = optimizer.state_dict()
# Re-initialize and check state
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
print(optimizer.state)
# Load and check state
optimizer.load_state_dict(state_dict)
print(optimizer.state)
Could you post the code you are using to store and load the optimizer?
It seems that there was a bug in my training code. I shouldn’t pass obj.state_dict() to the ModelCheckpoint:
checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict()})
Instead, I should pass the obj itself:
checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
'optimizer': optimizer,
'scheduler': scheduler})
model.state_dict() works in both situations, but for the optimizer and scheduler only the initial state will be saved. The reason might be the different implementations of state_dict() in nn.Module, optim.Optimizer, and optim._LRScheduler.
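A quick way to see the difference (a minimal sketch, not the original training code): nn.Module.state_dict() returns references to the live parameter tensors, so a snapshot taken once at handler registration still reflects later in-place updates, while Optimizer.state_dict() is rebuilt at each call, so an early snapshot keeps the empty initial state forever:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Snapshots taken once, as the buggy handler registration effectively did
model_snapshot = model.state_dict()
opt_snapshot = optimizer.state_dict()

# Run one update so both objects gain new internal state
model(torch.randn(1, 10)).mean().backward()
optimizer.step()

# The model snapshot shares storage with the parameters, so it tracks the update
print(torch.equal(model_snapshot['weight'], model.weight))  # True

# The optimizer snapshot was built before the step and stays empty
print(len(opt_snapshot['state']))                  # 0
print(len(optimizer.state_dict()['state']))        # 2 (weight and bias)
```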
Hi ptrblck,
I tried to print opt.state but it is always empty, like this:
defaultdict(<class 'dict'>, {})
What is state used for?
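state holds the per-parameter running statistics the optimizer carries across steps; for Adam these are the step count and the exponential moving averages of the gradient and its square. It is only populated on the first optimizer.step(), which is why a freshly created (or incorrectly restored) optimizer prints an empty defaultdict. A minimal sketch:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# No step taken yet, so the state dict is empty
print(len(opt.state))  # 0

model(torch.randn(1, 4)).sum().backward()
opt.step()

# After one step each parameter has an entry with the step count and the
# running averages Adam needs for its update rule
for p, s in opt.state.items():
    print(sorted(s.keys()))  # ['exp_avg', 'exp_avg_sq', 'step']
```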
Hi @ptrblck,
TL;DR: can you save optim.state and load it back in after a checkpoint, instead of it being re-initialized?
Apologies for reopening this post again (after 4 years), but is there a way to save optim.state when saving the state_dict, instead of it being re-initialized after re-initializing the class? The case I have in mind is training a model with a custom optimizer that uses self.state in some way, where the run is interrupted and has to resume from a checkpoint/torch.save file. The behavior won’t be consistent between a model that ran without interruption and one that did.
Yes, you can store the optimizer.state_dict() and load it via optimizer.load_state_dict(state_dict), as is also done for the model.
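A minimal save/resume round trip (the file name and single-dict checkpoint layout are illustrative, not from the original post):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(1, 10)).mean().backward()
optimizer.step()

# Save model and optimizer state together in one checkpoint file
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict()}, 'checkpoint.pt')

# Resume: re-create the objects, then restore both state_dicts
model = nn.Linear(10, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ckpt = torch.load('checkpoint.pt')
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])

# The per-parameter Adam state is back, so training continues smoothly
print(len(optimizer.state))  # 2 (entries for weight and bias)
```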