DDP: how to make GPU augmentation happen only on the specified GPU

Hi! I’m trying to implement GPU augmentation with DDP. I was able to implement the code with just one GPU and wanted to extend it to work with multiple GPUs. However, it failed.

(Note that the DDP part itself should not be the problem, since the code worked when I used CPU-based augmentations.)
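For context, the CPU-augmentation version that worked looked roughly like this (a simplified sketch, details omitted; the tensors stay on the CPU in __getitem__ and are only moved to the GPU in the training loop):

    def __getitem__(self, idx):
        sub_img, sub_label = self.dataset[idx]
        y1 = self.transform(from_numpy(sub_img).float())        # augmentation on the CPU
        y2 = self.transform_prime(from_numpy(sub_img).float())  # augmentation on the CPU
        return (y1, y2), sub_label

and in main_worker, per process:

for step, ((y1, y2), _) in enumerate(loader, start=epoch * len(loader)):
    y1 = y1.cuda(gpu, non_blocking=True)
    y2 = y2.cuda(gpu, non_blocking=True)
    ...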

The general structure of the code is the following:

First, when constructing the dataset, I pass in the GPU number (rank) as input:

def main():
    ...
    torch.multiprocessing.spawn(main_worker, (args,), args.ngpus_per_node)

def main_worker(gpu, args):
    args.rank += gpu
    torch.distributed.init_process_group(
        backend='nccl', init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)
    ...
    torch.cuda.set_device(gpu)
    ... 
    dataset = MRI_dataset(data_path, "train", 1, transform_yAware_all, (1, 99, 117, 95), MNI=True, gpu=args.rank)
    ....


    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    per_device_batch_size = args.batch_size // args.world_size
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=per_device_batch_size, num_workers=args.workers,
        sampler=sampler, drop_last=True)

where MRI_dataset is the following:

class MRI_dataset(Dataset):
    """
    split = 'train' or 'test'
    split_prop = value between 0 and 1, how much of the data to take as training when splitting
    transform = transform to be performed
    shape = shape of the image (MUST INCLUDE CHANNEL)
        * if MNI==True, does not do cropping or anything
        * if MNI==False, apply ResizeWithPadOrCrop
    MNI = fixed shape or not
    """
    
    def __init__(self, data_path, split, split_prop, transform, shape, MNI=True, gpu=0):  # add gpu as input
        assert len(shape) == 4, "shape must be given WITH the channel too! (C, H, W, D)"
        ...
        self.device = f"cuda:{gpu}"

    def __getitem__(self, idx):
        sub_data = self.dataset[idx]
        sub_img, sub_label = self.dataset[idx]  # pull out the image of the subject at this idx
        
        if self.split == 'train':
            """below : major revision, so check again (copy 안해도?)"""            
            y1 = self.transform(from_numpy(sub_img).float().to(self.device)) #load the dataloader worker to the correct gpu (each gpu does its own augmentation)
            y2 = self.transform_prime(from_numpy(sub_img).float().to(self.device))

            return (y1, y2), sub_label

This is the error log:

Traceback (most recent call last):
  File "main_3D.py", line 372, in <module>
    main()
  File "main_3D.py", line 87, in main
    torch.multiprocessing.spawn(main_worker, (args,), args.ngpus_per_node)
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/scratch/connectome/dyhan316/VAE_ADHD/barlowtwins/main_3D.py", line 152, in main_worker
    for step, ((y1, y2), _) in enumerate(loader, start=epoch * len(loader)):
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 530, in __next__
    data = self._next_data()
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1224, in _next_data
    return self._process_data(data)
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1250, in _process_data
    data.reraise()
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/_utils.py", line 457, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 172, in default_collate
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 172, in <listcomp>
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 172, in default_collate
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 172, in <listcomp>
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/connectome/dyhan316/.conda/envs/VAE_3DCNN_older_MONAI/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 137, in default_collate
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))
RuntimeError: Attempted to set the storage of a tensor on device "cuda:1" to a storage on different device "cuda:0".  This is no longer allowed; the devices must match.


torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2228.)
[W CudaIPCTypes.cpp:15] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]

I have another question that might be pertinent: in the code above, when I set num_workers, am I right in assuming that if I run the code with two GPUs, there will be two DataLoaders (one per process, one process per GPU) and hence 4 workers in total (2 × 2)?
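For reference, here is how I picture the worker layout, as a minimal sketch (the worker_init_fn is only added to count the workers; everything else is the same as my loader above):

import os
import torch

def worker_init_fn(worker_id):
    # runs once inside every DataLoader worker process
    info = torch.utils.data.get_worker_info()
    print(f"pid {os.getpid()}: worker {info.id} of {info.num_workers} in this process")

loader = torch.utils.data.DataLoader(
    dataset, batch_size=per_device_batch_size, num_workers=args.workers,
    sampler=sampler, drop_last=True, worker_init_fn=worker_init_fn)

With 2 GPUs (2 spawned processes) and num_workers=2 per loader, I would expect 4 such lines in total.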