Loading checkpoint from state_dict does NOT restore progress

Hi, I’m using DDP and trying to resume training from a checkpoint.
However, even after I restore the model’s state_dict, it still trains from scratch rather than initializing from the checkpoint.

Code: Checkpoint Loading process

This has been absolutely maddening me for quite some time now :frowning: I can confirm from the print statement above that chkp_file is indeed the latest checkpoint available…

As you can see above, everything is fairly standard here - the only differences are that I use DDP (and fp16), and I make sure to follow the methods outlined in the docs for saving/restoring checkpoints with DDP.

Does anyone have any idea what may be causing this? Cheers.

As a guess, could you try:

model.module.load_state_dict(checkpoint['model_state_dict'], strict=True)

To explain: you are saving the state dict of the module wrapped by DDP (i.e. model.module), but you are loading it into the DDP instance itself (i.e. model). That is why my suggestion is to try loading it into the wrapped module, model.module, instead.
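To illustrate the symmetry, here is a minimal sketch. Since no process group is needed to show the key layout, `DDPStub` below is a hypothetical stand-in for `DistributedDataParallel` - like real DDP, it stores the wrapped model as its `.module` attribute, so its state-dict keys carry a `module.` prefix:

```python
import torch
import torch.nn as nn

class DDPStub(nn.Module):
    """Hypothetical stand-in for DistributedDataParallel: like DDP, it
    exposes the wrapped model as .module, so its own state_dict keys
    are prefixed with 'module.'."""
    def __init__(self, module):
        super().__init__()
        self.module = module

model = DDPStub(nn.Linear(4, 2))

# Save from the wrapped module, as in the original post:
# these keys are 'weight' and 'bias', with no 'module.' prefix.
checkpoint = {"model_state_dict": model.module.state_dict()}

# Load back into the wrapped module - not the wrapper - so the keys match:
model.module.load_state_dict(checkpoint["model_state_dict"], strict=True)
```

The rule of thumb is simply to keep save and load symmetric: if you saved `model.module.state_dict()`, load into `model.module` as well.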

Also, I wonder what happens if you only set strict=True in your current code. With strict=False, it may be that none of the parameters are actually being loaded, but this happens silently. The reason none would be loaded is that none of the state dict keys match, because the model being loaded into has an additional module. prefix on its keys.


Thanks a ton, mate! :tada: I thought I had written model.module (sorry for opening a thread over such a silly typo). I removed the strict=True override too, and it all works perfectly. Again, many thanks, and I sincerely hope you have an amazing week!!