Load multiple models with DistributedDataParallel

I noticed an imbalance in GPU memory usage in my implementation.

import argparse
import torch

def train(device, args):
    torch.distributed.init_process_group(backend='nccl', rank=device, world_size=torch.cuda.device_count())
    model_A = A()
    model_B = B()

    # Not going to update A (frozen, eval only)
    model_A.to(device)
    ckpt = torch.load(path)
    model_A.load_state_dict(ckpt['model_state_dict'])
    model_A = torch.nn.parallel.DistributedDataParallel(model_A, device_ids=[device])
    model_A.eval()

    # Going to update B (trained)
    model_B.to(device)
    model_B = torch.nn.parallel.DistributedDataParallel(model_B, device_ids=[device])
    model_B.train()

if __name__ == '__main__':
    args = argparse.ArgumentParser().parse_args()
    torch.multiprocessing.spawn(train, nprocs=torch.cuda.device_count(), args=(args,))

I intend to load both models A and B on every GPU. However, the GPU memory usage shows that model_A is only allocated on GPU 0 and not on the other GPUs, like below:

GPU 0: 7000MiB / 11019MiB (model A, B)
GPU 1: 4000MiB / 11019MiB (model B)
GPU 2: 4000MiB / 11019MiB (model B)
GPU 3: 4000MiB / 11019MiB (model B)

Please feel free to ask if anything is unclear.

It could be that all processes unintentionally created a CUDA context on the default GPU (cuda:0). To avoid this, can you try setting the CUDA_VISIBLE_DEVICES environment variable to a different device for each process, so that each process only sees one GPU?
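Something like the sketch below is what I mean, reusing your spawn-based setup. This is only an assumption of how it could look: the env var has to be set before any CUDA call in the worker, the worker then always uses local device 0 (the only GPU it can see), and the world size must be passed in from the parent (e.g. a hypothetical args.world_size), since torch.cuda.device_count() would return 1 inside the worker after restricting visibility.

import os
import torch

def train(rank, args):
    # Restrict this process to a single GPU *before* any CUDA call,
    # so no context gets created on cuda:0 by the other ranks.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)

    # device_count() is now 1 in this process, so the world size
    # has to come from the parent process (assumed args.world_size).
    torch.distributed.init_process_group(backend='nccl', rank=rank, world_size=args.world_size)

    device = 0  # the only GPU visible to this process
    model_B = B().to(device)
    model_B = torch.nn.parallel.DistributedDataParallel(model_B, device_ids=[device])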

I solved this issue with map_location='cpu'. Loading the pretrained model with

ckpt = torch.load(path)

automatically allocates the parameters to GPU 0.
I still cannot understand why, but it is solved by passing map_location while loading the checkpoint:

ckpt = torch.load(path, map_location='cpu')
model_A.load_state_dict(ckpt['model_state_dict'])
model_A.to(device)


I think this is because all processes try to load the checkpoint tensors onto cuda:0 by default (the device they were saved from) if you don't set map_location or CUDA_VISIBLE_DEVICES. BTW, does directly setting map_location to the target device work for you?
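For reference, this is roughly what I have in mind, a sketch reusing path, device, and the checkpoint key from your snippets:

model_A = A().to(device)
# Map the checkpoint tensors straight onto this process's GPU
# instead of cuda:0 (or a detour through CPU).
ckpt = torch.load(path, map_location=f'cuda:{device}')
model_A.load_state_dict(ckpt['model_state_dict'])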