Data scattering with DistributedDataParallel

Hello, I’m trying to load my data with the DistributedSampler class in order to train a model on multiple GPUs. The model is wrapped with DistributedDataParallel. The data is loaded successfully on my 2 GPUs. Here are my code snippets:


            # distributed learning
            if torch.cuda.device_count() > 1:
                model = torch.nn.parallel.DistributedDataParallel(self.net, device_ids=[range(self.num_gpus)])
            else:
                model = self.net

            iteration = infos["iteration"]
            epoch_start = infos["epoch"]

            model.train()
            for epoch in range(epoch_start, cfg.TRAIN.MAX_EPOCH):
                self.setup_dataloader(epoch=epoch)

                for _, blobs in enumerate(self.loader):
                    print("blobs.size", len(blobs))
                    print(blobs)
                    loss_dict = model.forward(blobs)

blobs is a list of dicts containing tensors, the objects in each image, and other additional information (it’s an object detection task based on Faster R-CNN).
After calling model.forward(blobs), the following error is reported:

TypeError: list indices must be integers or slices, not range

The corresponding traceback of this error:

Traceback (most recent call last):
  File "tools/train.py", line 456, in <module>
    trainer.train(args)
  File "tools/train.py", line 372, in train
    loss_dict = model.forward(blobs)
  File "/vol/.conda/envs/.env36/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 445, in forward
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
  File "/vol/.conda/envs/.env36/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 471, in scatter
    return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
  File "/vol/.conda/envs/.env36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 36, in scatter_kwargs
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
  File "/vol/.conda/envs/.env36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 28, in scatter
    res = scatter_map(inputs)
  File "/vol/.conda/envs/.env36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 15, in scatter_map
    return list(zip(*map(scatter_map, obj)))
  File "/vol/.conda/envs/.env36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 17, in scatter_map
    return list(map(list, zip(*map(scatter_map, obj))))
  File "/vol/.conda/envs/.env36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 19, in scatter_map
    return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
  File "/vol/.conda/envs/.env36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 15, in scatter_map
    return list(zip(*map(scatter_map, obj)))
  File "/vol/.conda/envs/.env36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 13, in scatter_map
    return Scatter.apply(target_gpus, None, dim, obj)
  File "/vol/.conda/envs/.env36/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 88, in forward
    streams = [_get_stream(device) for device in target_gpus]
  File "/vol/.conda/envs/.env36/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 88, in <listcomp>
    streams = [_get_stream(device) for device in target_gpus]
  File "/vol/.conda/envs/.env36/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 115, in _get_stream
    if _streams[device] is None:
TypeError: list indices must be integers or slices, not range
(an identical traceback is printed a second time)

As far as I know, if the model input is plain tensor data, there is no problem training the model on multiple GPUs with DistributedDataParallel. Could the problem be that a list is used to pass the data into the model.forward() method?

It works if I launch the model on a single GPU only.

Thanks in advance.

I made a typo when passing devices: instead of torch.cuda.set_device(args.local_rank), I passed the wrong parameter, torch.cuda.set_device(range(2)).
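For reference, the corrected per-process device setup looks roughly like this (a sketch assuming one process per GPU, with args.local_rank holding that process's integer GPU index):

# pin this process to exactly one GPU, identified by an integer index
torch.cuda.set_device(args.local_rank)            # not range(2)
model = model.cuda(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[args.local_rank],                 # a list with one int, not [range(...)]
)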

After fixing this typo, I still have the same problem as posted in How to scatter list data on multiple GPUs.

Thanks for any input.

I checked the scatter implementation, and it looks like it can scatter tensors in dictionaries properly. What is the structure of the blobs variable that scatter fails to handle?

>>> x = {1:1, 2:2}
>>> x.items()
dict_items([(1, 1), (2, 2)])
>>> import torch
>>> from torch.nn.parallel.scatter_gather import scatter
>>> scatter(x, target_gpus=[0, 1])
[{1: 1, 2: 2}, {1: 1, 2: 2}]
>>> y = {1: torch.zeros(4, 4).to(0), 2: torch.zeros(4, 4).to(0)}
>>> scatter(y, target_gpus=[0, 1])
[{1: tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]], device='cuda:0'), 2: tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]], device='cuda:0')}, {1: tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]], device='cuda:1'), 2: tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]], device='cuda:1')}]

Hey @mrshenli, thanks for your reply.

I have already solved this issue. Basically, there are two methods (three, counting your idea :D).
Both are based on DistributedSampler and DistributedDataParallel: using the DistributedSampler, each process correctly loads its own subset of the data, as sketched below.
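A minimal sketch of what a setup_dataloader helper built on DistributedSampler can look like (the names and defaults here are illustrative rather than my exact implementation):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def setup_dataloader(dataset, is_distributed, epoch, batch_size=2, num_workers=4):
    # In the distributed case, each process only sees its own shard of the dataset.
    sampler = DistributedSampler(dataset) if is_distributed else None
    if sampler is not None:
        # Reshuffle the shards differently every epoch.
        sampler.set_epoch(epoch)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=lambda batch: batch,  # keep blobs as a list of dicts
    )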

We can either use torch.multiprocessing to manually spawn one process per GPU and launch the training procedure ourselves:

import os

import torch
import torch.multiprocessing as mp

def main():
    # Args definition and loading (assumed argparse helper, sketched below)
    args = parse_args()
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8080'
    # Spawn one training process per GPU; each process receives its index as `gpu`
    mp.spawn(train, nprocs=args.num_gpus, args=(args,))

def train(gpu, args):
    # Initialize the distributed package & process group
    # (single node, so the spawn index `gpu` is also the global rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://',
                                         world_size=args.world_size, rank=gpu)
    torch.manual_seed(0)

    # Initialize the model and pin this process to its GPU
    model = RCNN(cfg)
    torch.cuda.set_device(gpu)
    model.cuda(gpu)

    # Wrap the model
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # Initialize the dataset
    dataset = Dataset(cfg, is_train=True, split="train")

    # Kick off the iterations
    for epoch in range(cfg.TRAIN.MAX_EPOCH):
        loader = setup_dataloader(dataset, is_distributed=True, epoch=epoch)

        for _, blobs in enumerate(loader):
            loss_dict = model.forward(blobs)

if __name__ == "__main__":
    main()

and launch it by passing the args as follows:

python3  tools/train_mp.py  --num_gpus 2
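The spawn version assumes an args object along these lines (the argument names are inferred from the snippet; it is a single node, so world_size equals the number of GPUs):

import argparse

import torch

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_gpus", type=int, default=torch.cuda.device_count())
    args = parser.parse_args()
    # One process per GPU on a single node, so world_size == num_gpus.
    args.world_size = args.num_gpus
    return args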

or we can use the launch utility that torch already provides:

import torch
# RCNN, cfg, Dataset and setup_dataloader are project-specific imports, omitted here.

class TrainingWrapper(object):
    def __init__(self, args):
        self.setup_logging()
        self.args = args

        # Initialize the distributed package & process group
        self.num_gpus = torch.cuda.device_count()
        print("world_size:%d, local_rank:%d" % (args.num_gpus, args.local_rank))
        self.distributed = self.num_gpus > 1
        if self.distributed:
            torch.distributed.init_process_group(
                backend="nccl",
                init_method="env://",
                world_size=args.num_gpus,
                rank=args.local_rank
            )
        self.device = args.local_rank

        # This line is very important!
        torch.cuda.set_device(self.device)

        # Initialize the model
        model = RCNN(cfg)

        # Distributed learning
        if self.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model.cuda(self.device),
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,  # a single device index, not a list
                broadcast_buffers=False
            )
        else:
            model = torch.nn.DataParallel(model).cuda()
        self.model = model

    def train(self):
        # Initialize the dataset
        dataset = Dataset(cfg, is_train=True, split="train")

        # Kick off the iterations
        for epoch in range(cfg.TRAIN.MAX_EPOCH):
            loader = setup_dataloader(dataset, is_distributed=True, epoch=epoch)

            for _, blobs in enumerate(loader):
                loss_dict = self.model.forward(blobs)

if __name__ == "__main__":
    trainer = TrainingWrapper(parse_args())  # parse_args sketched below
    trainer.train()

and pass the following args to the script above:

python3 -m torch.distributed.launch --nproc_per_node=2 tools/train_ddp.py  --exp_id boa --config_file experiments/config.yaml --num_gpus 2
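torch.distributed.launch starts one process per GPU and passes each of them a --local_rank argument, so the wrapper above assumes argument parsing roughly like this (the other flags are inferred from the command line above):

import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    # Supplied automatically by torch.distributed.launch, one value per process.
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--num_gpus", type=int, default=1)
    parser.add_argument("--exp_id", type=str, default="boa")
    parser.add_argument("--config_file", type=str, default="experiments/config.yaml")
    return parser.parse_args()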

With either approach, I can just use a list of dict objects to feed the data into the model.
Hopefully this helps somebody else.
