With DataParallel and DistributedDataParallel, the parameter replicas on the different GPUs stay identical (unless a GPU communication failure occurs), so it is enough to keep a single copy: save the parameters from GPU 0 (or rank 0) and load them later.
- saving
import torch
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

if isinstance(model, (DataParallel, DistributedDataParallel)):
    # the actual model is wrapped in .module; saving it keeps the keys free of the "module." prefix
    torch.save(model.module.state_dict(), model_save_name)
else:
    torch.save(model.state_dict(), model_save_name)
- loading
# map_location moves the saved tensors straight onto this process's device
state_dict = torch.load(model_save_name, map_location=current_gpu_device)
if isinstance(model, (DataParallel, DistributedDataParallel)):
    model.module.load_state_dict(state_dict)
else:
    model.load_state_dict(state_dict)
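- end-to-end example
A minimal sketch of the whole round trip; the toy Sequential model and the checkpoint path "ckpt.pth" are placeholders chosen for illustration, not from the original snippet. The same pattern applies to DistributedDataParallel, where you would typically save only on rank 0.

import torch
import torch.nn as nn
from torch.nn import DataParallel

# toy model used only for this example
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
wrapped = DataParallel(net) if torch.cuda.is_available() else net

# save: always store the unwrapped parameters so the keys carry no "module." prefix
to_save = wrapped.module if isinstance(wrapped, DataParallel) else wrapped
torch.save(to_save.state_dict(), "ckpt.pth")

# load: a plain, non-parallel copy of the same architecture can consume the file
restored = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
state_dict = torch.load("ckpt.pth", map_location="cpu")
restored.load_state_dict(state_dict)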