Distributed Parallel, one machine multi gpu multi process?

Hi everyone,

I am trying to train a model with one machine, but with multi gpus. Until now, I was using the nn.DataParallel which works well, but it seems a bit slow to me so I would like to use the DistributedDataParallel instead.

However, I am not sure to understand clearly how to use this function (I have some weird results, the training takes 10x much more time than DataParallel).
In fact, I am not sure which gpus have to load the model/batch and compute the loss function ?
Moreover with the code below, my training is slower and I saw on nvidia-smi a weird behavior. Instead of having ONE process on each gpu, I have two process for each gpu (I have two gpus, but I have 4 process) .

My second issue is if I increase the number of workers in the dataloader, I have a dataloader pid killed error .
Am I doing something wrong ?

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from apex.parallel import DistributedDataParallel as DDP_apex
from torch.nn.parallel import DistributedDataParallel as DDP

def run(gpu, args):
	rank = gpu                   
        dist.init_process_group(                                   
    	        backend='nccl',                                         
   		init_method='tcp://localhost:1088', #'env://',                                   
    	        world_size=args.world_size,                              
    	       rank=rank                                               
              )      
	trainset = ...
	testset = ...
	
	################################################################
        train_sampler = torch.utils.data.distributed.DistributedSampler(
    	trainset,
    	num_replicas=args.world_size,
    	rank=rank
    )

    ################################################################
        test_sampler = torch.utils.data.distributed.DistributedSampler(
    	testset,
    	num_replicas=args.world_size,
    	rank=rank
    )

	trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True,                        num_workers=args.workers, pin_memory=False, drop_last= True )
        testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=False)
	optim_params = list(filter(lambda p: p.requires_grad, net.parameters()))
	optimizer = optim.Adam(optim_params, lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=args.weight_decay, amsgrad=True)
	net = 2D_CNN()
	net = net.to(args.gpus[0])
	
	net = DDP(net, device_ids=args.gpus)
	
	train(net, optimizer,  trainloader, testloader, args, gpu) # function which iterate accross the dataloader and do the forward/backward/step

if __name__ =="__main__":
    args.nodes = 1 # one single machine
    args.gpus = [0,1,2]
    #########################################################
    args.world_size = len(args.gpus) * args.nodes                #
    os.environ['MASTER_ADDR'] = 'localhost'              #
    os.environ['MASTER_PORT'] = '8888'                      #
    print(args.gpus)
    print(os.environ['MASTER_PORT'] )
    mp.spawn(run, nprocs=len(args.gpus), args=(args,))         #
1 Like

Related to my question, read here

One requirement for DistributedDataParallel is that, you need to set device_ids properly or use CUDA_VISIBLE_DEVICES env var to configure them properly to make sure that one process only works on one GPU. Otherwise, by default, each process will try to use all visible GPUs.

Thanks for your replies.

@mrshenli How can I configure the devices_ids if I want to have one process = one gpu ?
In my case, I have tried to change device_ids by the gpu number (which is equal to the rank value in my case) such as device_ids = [gpu] and call transfer the model in the first gpu model.to(args.gpus[0]) .
But unfortunately, I got that error : RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

P.S : it seems to “work” now, I forgot to cast the loss function on the good gpu. But, when I increase the number of workers for the dataloader (>0) it creates the error : RuntimeError: DataLoader worker (pid(s) 16451) exited unexpectedly . (and the training took above 24 hours). I will try with the apex distributed and then again compare with DataParallel.

Please report back if you find that distributed is faster than data parallel (i.e. time it and lets us know). My intuition says that the claim is false but I haven’t seen anywhere on the docs a proper MWE on how to properly use distributed for each individual case in combination with launch.py.

cc @vincentqb on RuntimeError from DataLoader

This is a minimum DDP example. Posting it here in case it is useful for future readers.

That doesn’t cover launch.py examples.

I would like to see examples using launch.py covering the following two separate use cases

1 Like

@mrshenli Thanks for the link.

I see that my training take exactly the same amount of time with 1 ou 2 gpus, so I was wondering, when we use Distributed process, do we have to divide the original number of iteration by the number of gpus used ?

Hey @kirk86, could you please create an issue on GitHub for us to track this request? Thanks!

@Shiro

If each process runs the same amount of iterations with each iteration consuming the same amount of data, using 2GPUs might actually take longer, because there will be additional communication overhead between the two GPUs. But in this case, your model is actually trained using 2X number of batches.

Reducing the number of iterations should work, or you can also reduce the batch size. One thing to note is that, this might also call for additional tuning on the learning rate or other configs. Some relevant discussions are available here.

Done!

1 Like