There are a few things to clarify.
- As you are using the `resnet18` from torchvision, the model only lives on a single GPU.
- The launcher script you use starts `num_gpus` processes, and each process has its own DDP instance, dataloader, and model replica.
- With 1 and 2, your training script only needs to put the model on one GPU (you can use the rank as the device id) and load the data onto that GPU; the DDP instance will handle the communication for you and make sure that all model replicas stay synchronized.
- With the above three points, the question then becomes "how do I load a model onto a specific GPU device?" The answer is to use `map_location=torch.device(rank)`.
The following code works for me with the launch command:

```shell
python -m torch.distributed.launch --nproc_per_node=2 test.py
```
```python
import argparse

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.models import resnet18


def cleanup():
    dist.destroy_process_group()


def main(args):
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://localhost:23456',
        rank=args.local_rank,
        world_size=2,
    )
    torch.cuda.set_device(args.local_rank)
    model = resnet18()
    path = "save_model.pt"
    if args.local_rank == 0:
        # save the CPU model
        torch.save(model, path)
    dist.barrier()
    # load the model onto this process's GPU
    loaded_model = torch.load(path, map_location=torch.device(args.local_rank))
    model = DDP(loaded_model, device_ids=[args.local_rank])
    print(f"Rank {args.local_rank} training on device {list(model.parameters())[0].device}")
    # create a dedicated data loader for each process
    cleanup()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="blah")
    parser.add_argument("--local_rank", type=int)
    args, _ = parser.parse_known_args()
    main(args)
```
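For the "create a dedicated data loader for each process" step, a common approach is `DistributedSampler`, which partitions the dataset so each rank iterates over a disjoint shard. A minimal single-process sketch (here `num_replicas` and `rank` are passed explicitly so it runs standalone; inside the launched script you would normally omit them and let the sampler read them from the initialized process group):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset of 8 samples; with 2 replicas, each rank should see 4 of them.
dataset = TensorDataset(torch.arange(8).float())

# Explicit num_replicas/rank so this sketch runs in a single process;
# in the real training script the sampler picks these up from dist.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

# Call set_epoch each epoch so the shuffling differs across epochs.
sampler.set_epoch(0)
for (batch,) in loader:
    print(batch.shape)  # this rank only sees its own shard
```

Note that the shards from all ranks together cover the full dataset, so DDP's gradient averaging sees every sample exactly once per epoch (when the dataset size divides evenly by the world size).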