I’d like to double-check because I see multiple questions about this and the official doc is confusing to me. I want to:
- resume from a checkpoint to continue training on multiple GPUs
- save a checkpoint correctly during training with multiple GPUs
For that my guess is the following:
- to do 1, have all processes load the checkpoint from the file, then call `DDP(mdl)` in each process. I assume the checkpoint saved `ddp_mdl.module.state_dict()`.
- to do 2, simply check which process has rank 0 and have that one do `torch.save({'model': ddp_mdl.module.state_dict()})`.
Is this correct?
I am not sure why there are so many other posts about this (also this), or what the official doc is getting at. I don't get why I'd want to optimize a single write (or read) of the checkpoint, given that I will train my model for anywhere from 100 epochs up to 600K iterations… that one write seems negligible to me.
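As far as I can tell, the pattern in the official DDP tutorial is roughly: rank 0 writes the checkpoint once, everyone waits on a barrier, and then each rank loads the same file with a per-rank `map_location` so the tensors don't all get deserialized onto GPU 0. A minimal sketch of that pattern as I understand it (the model and path are placeholders, and the process group is assumed to be initialized already):

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def save_and_reload(rank, model, path="model.checkpoint"):
    # assumes dist.init_process_group(...) has already been called
    ddp_model = DDP(model.to(rank), device_ids=[rank])

    if rank == 0:
        # one process writes; all replicas are identical, so one copy is enough
        torch.save(ddp_model.module.state_dict(), path)

    # make sure rank 0 has finished writing before anyone tries to read
    dist.barrier()

    # remap tensors saved from rank 0's GPU onto this rank's GPU instead of
    # piling everything onto cuda:0
    map_location = {'cuda:0': f'cuda:{rank}'}
    ddp_model.module.load_state_dict(torch.load(path, map_location=map_location))
    return ddp_model
```

The barrier matters because without it a non-zero rank can try to read the file before rank 0 has finished writing it.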
I think the correct code is this one:
    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    def save_ckpt(rank, ddp_model, optimizer, path):
        # only rank 0 writes the file; every replica holds identical weights
        if rank == 0:
            state = {'model': ddp_model.module.state_dict(),
                     'optimizer': optimizer.state_dict(),
                     }
            torch.save(state, path)

    def load_ckpt(path, distributed, gpu, map_location=torch.device('cpu')):
        # load the checkpoint onto CPU first
        checkpoint = torch.load(path, map_location=map_location)
        model = Net(...)
        if distributed:
            # DDP with device_ids expects the module to already be on that GPU
            model = model.to(gpu)
        optimizer = ...  # built from model.parameters()
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if distributed:
            model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
        return model, optimizer
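For what it's worth, a rough sketch of how I'd wire these helpers into a per-GPU worker (everything besides `save_ckpt`/`load_ckpt` is a placeholder here: `Net`, `num_epochs`, the loop body, and the rendezvous env vars are assumed to exist elsewhere):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def train_worker(gpu, world_size, ckpt_path, resume=False):
    # one process per GPU; on a single node the rank is just the GPU index
    # (MASTER_ADDR/MASTER_PORT are assumed to be set in the environment)
    dist.init_process_group("nccl", rank=gpu, world_size=world_size)
    torch.cuda.set_device(gpu)

    if resume and os.path.exists(ckpt_path):
        # every rank reads the same file, then wraps its copy in DDP
        model, optimizer = load_ckpt(ckpt_path, distributed=True, gpu=gpu)
    else:
        model = Net(...).to(gpu)                      # placeholder model
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        model = DDP(model, device_ids=[gpu])

    for epoch in range(num_epochs):                   # placeholder loop
        ...                                           # forward / backward / step
        save_ckpt(dist.get_rank(), model, optimizer, ckpt_path)
        dist.barrier()  # keep ranks in sync around the save
```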
Also, I added more details here: python - What is the proper way to checkpoint during training when using distributed data parallel (DDP) in PyTorch? - Stack Overflow