Checkpointing in Multi-GPU Training

With DataParallel and DistributedDataParallel, the model parameters replicated on each GPU are identical (unless a GPU communication failure occurs).

You can simply save the parameters from the process on GPU 0 and load them later (a fuller DistributedDataParallel sketch follows the snippets below).

  • saving
import torch
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

if isinstance(model, (DataParallel, DistributedDataParallel)):
    torch.save(model.module.state_dict(), model_save_name)  # unwrap .module so keys have no "module." prefix
else:
    torch.save(model.state_dict(), model_save_name)
  • loading
# map_location moves the tensors onto this process's GPU instead of the GPU they were saved from
state_dict = torch.load(model_name, map_location=current_gpu_device)
if isinstance(model, (DataParallel, DistributedDataParallel)):
    model.module.load_state_dict(state_dict)
else:
    model.load_state_dict(state_dict)
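
In an actual DistributedDataParallel run, every process executes the same script, so it is common to let only rank 0 write the checkpoint and to synchronize the ranks before anyone reads it back. Below is a minimal sketch under that assumption; the save_checkpoint/load_checkpoint names, the path, and the device argument are placeholders, not part of any PyTorch API.

import torch
import torch.distributed as dist

def save_checkpoint(ddp_model, path):
    # the replicas hold identical parameters, so one writer is enough
    if dist.get_rank() == 0:
        torch.save(ddp_model.module.state_dict(), path)
    dist.barrier()  # wait until the file exists before any rank tries to load it

def load_checkpoint(ddp_model, path, device):
    # map_location keeps each rank's copy on its own GPU
    state_dict = torch.load(path, map_location=device)
    ddp_model.module.load_state_dict(state_dict)

Calling save_checkpoint on every rank is then safe: the non-zero ranks skip the write and just wait at the barrier.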