Hi, I’m having trouble loading a DistributedDataParallel model onto just 1 GPU. I want to know how to load a model (trained on 4 GPUs with DistributedDataParallel) in another job that uses only 1 GPU.
I have trained a model on 4 GPUs with DistributedDataParallel, and I saved it as shown in the tutorial:

```python
if rank == 0:
    torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
dist.barrier()
```

However, I don’t know how to load it using just 1 GPU for a simple job like a validation test.
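(As an aside, I have seen the suggestion to save the state_dict of the unwrapped model instead, so the checkpoint keys carry no `module.` prefix. A sketch of that variant, assuming the same `rank` and `CHECKPOINT_PATH` as above:)

```python
# save the underlying model rather than the DDP wrapper,
# so the keys are not prefixed with "module."
if rank == 0:
    torch.save(ddp_model.module.state_dict(), CHECKPOINT_PATH)
dist.barrier()
```

My existing checkpoints were saved with the wrapped state_dict above, though.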
I’m now using this method:
```python
# initialize the process group
torch.distributed.init_process_group(backend="nccl")
local_rank = torch.distributed.get_rank()
print(local_rank)
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
print(device)

model = resnet50()
model.to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
model.load_state_dict(torch.load(cfg.MODEL.pretrained_model_path))
model.eval()

# only the process with rank 0 runs the test
if local_rank == 0:
    acc, acc_std, th = lfw_test(model, cfg.TEST.lfw_root, cfg.TEST.lfw_test_list)
```
and launch it with this command:
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 test.py
This method works: nvidia-smi shows that only GPU 0 is doing actual work. But when I run another test process on GPUs 1–3 (while the previous one is still running):
CUDA_VISIBLE_DEVICES=1,2,3 python -m torch.distributed.launch --nproc_per_node=3 test.py
the previous process throws a runtime error:
RuntimeError: NCCL error in: /opt/conda/conda-bld/pytorch_1549633347309/work/torch/lib/c10d/ProcessGroupNCCL.cpp:260, unhandled system error
There are two reasons why I want to load the model on a single GPU:
- Some jobs include a file-writing step, and distributed parallel execution can cause the writes to happen in the wrong order.
- Running 4 tiny experiments, each as its own single-GPU process, is a more efficient way for me to test ideas and find bugs.
So, is there a way to load the model in the common single-GPU way?
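Something like this is what I have in mind (a sketch of what I would like to work; the key renaming assumes the checkpoint was saved from the DDP wrapper as above, so every key carries a `module.` prefix):

```python
import torch
from torchvision.models import resnet50

model = resnet50()

# load the DDP checkpoint directly onto a single GPU
state_dict = torch.load(CHECKPOINT_PATH, map_location="cuda:0")
# strip the "module." prefix added by the DistributedDataParallel wrapper
state_dict = {k[len("module."):] if k.startswith("module.") else k: v
              for k, v in state_dict.items()}
model.load_state_dict(state_dict)

model.to("cuda:0")
model.eval()
```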