When checkpointing, is it OK to save ddp_model.module instead of ddp_model? I need to be able to use the checkpoint for (1) evaluation on a single GPU and (2) resuming training on multiple GPUs.
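```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def save_ckpt(ddp_model, optimizer, path):
    # Save the wrapped module's weights (keys without the "module." prefix)
    state = {'model': ddp_model.module.state_dict(),
             'optimizer': optimizer.state_dict()}
    torch.save(state, path)

def load_ckpt(path, distributed, gpu, map_location='cpu'):
    checkpoint = torch.load(path, map_location=map_location)
    model = Net(...)
    optimizer = ...
    model.load_state_dict(checkpoint['model'], strict=False)
    if distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # DDP with device_ids expects the module to already be on that GPU
        model = DDP(model.cuda(gpu), device_ids=[gpu], find_unused_parameters=True)
    return model
```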
I am not sure how this would behave: does manipulating ddp_model.module in general break things?
Similarly, what if the number of GPUs changes? Would that affect the optimizer, for example, and does it need to be reinitialized? Thank you!
On save, use a single rank (e.g., rank 0) to save ddp_model.module to the checkpoint.
On load, first load the checkpoint into a local (unwrapped) model, and then wrap that local model instance with DDP on all processes (i.e., something like DistributedDataParallel(local_model, device_ids=[rank])).
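A minimal sketch of why saving ddp_model.module is convenient, assuming a toy `nn.Linear` model, the CPU `gloo` backend, and a one-process group (world size 1) so it runs without GPUs. The helper name `demo_prefix` and the file-based rendezvous path are illustrative only. The wrapper's own `state_dict()` prefixes every key with `module.`, while `ddp_model.module.state_dict()` loads directly into an unwrapped model such as one used for single-GPU evaluation.

```python
import os
import tempfile

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def demo_prefix():
    # One-process "group" on the CPU gloo backend, just so DDP can be built here
    init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
    dist.init_process_group("gloo", init_method=f"file://{init_file}",
                            rank=0, world_size=1)
    try:
        ddp_model = DDP(nn.Linear(4, 2))  # CPU module, so no device_ids

        # Keys of the DDP wrapper vs. keys of the wrapped module
        wrapped_keys = list(ddp_model.state_dict().keys())
        inner_state = ddp_model.module.state_dict()

        # The inner state_dict loads directly into an unwrapped model
        # (load_state_dict is strict by default, so the keys must match exactly)
        eval_model = nn.Linear(4, 2)
        eval_model.load_state_dict(inner_state)
        return wrapped_keys, list(inner_state.keys())
    finally:
        dist.destroy_process_group()


wrapped, inner = demo_prefix()
print(wrapped)  # every key carries the "module." prefix
print(inner)    # plain parameter names
```

If a checkpoint was already saved from the wrapper itself, `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")` strips the prefix in place before loading.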
I'd like to double-check, because I see multiple questions about this and the official docs are confusing to me. I want to:
1. resume from a checkpoint to continue training on multiple GPUs
2. save checkpoints correctly during training with multiple GPUs
My guess for each is the following:
For 1, all processes load the checkpoint from the file into a plain model, and then each process calls DDP(mdl). I assume the checkpoint stored ddp_mdl.module.state_dict().
For 2, simply check which process has rank == 0 and have only that one call torch.save({'model': ddp_mdl.module.state_dict()}, path).
Is this correct?
I am not sure why there are so many other posts about this (also this) or what the official doc is talking about. I don't see why I'd want to optimize a single write (or read) when I will train my model for anywhere from 100 epochs to 600K iterations; optimizing that one write seems pointless to me.
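The two steps guessed above can be sketched as follows. This is a minimal illustration under stated assumptions: `build_model`, `build_optimizer`, `save_checkpoint` and `load_for_resume` are hypothetical names, SGD with momentum is an arbitrary choice, and the demo runs a one-process CPU group with the `gloo` backend instead of real GPUs. The `dist.barrier()` after saving mirrors the checkpoint example in the PyTorch DDP tutorial; it keeps other ranks from reading the file before rank 0 has finished writing it.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def build_model():
    # Hypothetical stand-in for Net(...)
    return nn.Linear(4, 2)


def build_optimizer(params):
    # Hypothetical optimizer choice; momentum gives it per-parameter state
    return torch.optim.SGD(params, lr=0.1, momentum=0.9)


def save_checkpoint(ddp_model, optimizer, path):
    # Step 2: only rank 0 writes; all ranks hold identical weights after a step
    if dist.get_rank() == 0:
        torch.save({"model": ddp_model.module.state_dict(),
                    "optimizer": optimizer.state_dict()}, path)
    # Keep other ranks from reading the file before it is fully written
    dist.barrier()


def load_for_resume(path, map_location="cpu"):
    # Step 1: every rank loads into a plain model, then wraps it in DDP.
    # With GPUs, map_location would typically be f"cuda:{local_rank}" and the
    # model would be moved there and wrapped with device_ids=[local_rank].
    checkpoint = torch.load(path, map_location=map_location)
    model = build_model()
    model.load_state_dict(checkpoint["model"])
    ddp_model = DDP(model)
    optimizer = build_optimizer(ddp_model.parameters())
    optimizer.load_state_dict(checkpoint["optimizer"])
    return ddp_model, optimizer


# Demo: a one-process CPU group (gloo backend) stands in for a multi-GPU job
tmp_dir = tempfile.mkdtemp()
dist.init_process_group("gloo",
                        init_method=f"file://{os.path.join(tmp_dir, 'pg_init')}",
                        rank=0, world_size=1)

ddp_model = DDP(build_model())
optimizer = build_optimizer(ddp_model.parameters())
ddp_model(torch.randn(8, 4)).sum().backward()
optimizer.step()  # creates momentum buffers, so there is optimizer state to save

ckpt_path = os.path.join(tmp_dir, "ckpt.pt")
save_checkpoint(ddp_model, optimizer, ckpt_path)
resumed_model, resumed_optimizer = load_for_resume(ckpt_path)

dist.destroy_process_group()
```

With standard optimizers such as SGD, `optimizer.state_dict()` is keyed by parameter position rather than by rank, and every rank steps on the same all-reduced gradients, so the state saved from rank 0 should also load into a job with a different number of GPUs.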
I think the correct code is this one:
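```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def save_ckpt(rank, ddp_model, optimizer, path):
    # Only rank 0 writes the file
    if rank == 0:
        state = {'model': ddp_model.module.state_dict(),
                 'optimizer': optimizer.state_dict()}
        torch.save(state, path)

def load_ckpt(path, distributed, gpu, map_location=torch.device('cpu')):
    # map_location sets the device the saved tensors are loaded onto
    checkpoint = torch.load(path, map_location=map_location)
    model = Net(...)
    # Restore weights into the unwrapped model before wrapping it with DDP
    model.load_state_dict(checkpoint['model'])
    model = model.cuda(gpu)
    if distributed:
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
    optimizer = ...  # build from model.parameters() after moving/wrapping
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer
```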