Using torch.nn.parallel.DistributedDataParallel with DALI

Hi, I have an interactive pipe1 -> pipe2 -> NN workflow which is explained here.
I want to parallelize this on a distributed-memory system which has 2 GPUs per node.
I want to put one pipe1 -> pipe2 -> NN apparatus per process (rank) and map one rank per GPU.
As a result I will have two (pipe1 -> pipe2 -> NN)s per node.
I am using this example.
I am including my complete file below so you can see everything, but the error occurs in the following lines of code:

210         # For distributed training, wrap the model with torch.nn.parallel.DistributedDataParallel.
211         if args.distributed:
212                 model = DDP(model, device_ids=[args.gpu], output_device=args.gpu)
213                 if args.verbose:
214                         print('Since we are in a distributed setting the model is replicated here in rank {}' .format(args.local_rank))

The output from the run is the following (please ignore the first three ERROR lines; I believe they are related to containerization issues):

Singularity> python3.8 -m torch.distributed.launch --nproc_per_node=2 Contrastive_Learning.py --b 128 -v /projects/neurophon/
ERROR: ld.so: object '/soft/buildtools/trackdeps/${LIB}/trackdeps.so' from LD_PRELOAD cannot be preloaded: ignored.
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
ERROR: ld.so: object '/soft/buildtools/trackdeps/${LIB}/trackdeps.so' from LD_PRELOAD cannot be preloaded: ignored.
ERROR: ld.so: object '/soft/buildtools/trackdeps/${LIB}/trackdeps.so' from LD_PRELOAD cannot be preloaded: ignored.
distributed is True, then rank number 1 is mapped in device number 1
Using device cuda:1 in rank number 1
distributed is True, then rank number 0 is mapped in device number 0
Using device cuda:0 in rank number 0
function_f created from rank 1
function_f created from rank 0
function_g created from rank 1
SimCLR_Module created from rank 1
function_g created from rank 0
SimCLR_Module created from rank 0
pipe1 built by rank number 1 in 1.4155216217041016 seconds
Initialating fixation

pipe2 built by rank number 1 in 0.0034644603729248047 seconds
pipe1 built by rank number 0 in 1.418199062347412 seconds
Initialating fixation

pipe2 built by rank number 0 in 0.003275632858276367 seconds
Traceback (most recent call last):
  File "Contrastive_Learning.py", line 294, in <module>
    main()
  File "Contrastive_Learning.py", line 212, in main
    model = DDP(model, device_ids=[args.gpu], output_device=args.gpu)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 335, in __init__
    self._ddp_init_helper()
  File "/usr/local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 441, in _ddp_init_helper
    self.reducer = dist.Reducer(
RuntimeError: CUDA error: the launch timed out and was terminated
terminate called after throwing an instance of 'dali::CUDAError'
  what():  CUDA runtime API error cudaErrorLaunchTimeout (6):
the launch timed out and was terminated
Traceback (most recent call last):
  File "Contrastive_Learning.py", line 294, in <module>
    main()
  File "Contrastive_Learning.py", line 212, in main
    model = DDP(model, device_ids=[args.gpu], output_device=args.gpu)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 335, in __init__
    self._ddp_init_helper()
  File "/usr/local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 441, in _ddp_init_helper
    self.reducer = dist.Reducer(
RuntimeError: CUDA error: the launch timed out and was terminated
terminate called after throwing an instance of 'dali::CUDAError'
  what():  CUDA runtime API error cudaErrorLaunchTimeout (6):
the launch timed out and was terminated
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.8/site-packages/torch/distributed/launch.py", line 261, in <module>
    main()
  File "/usr/local/lib/python3.8/site-packages/torch/distributed/launch.py", line 256, in main
    raise subprocess.CalledProcessError(returncode=process.returncode,
subprocess.CalledProcessError: Command '['/usr/local/bin/python3.8', '-u', 'Contrastive_Learning.py', '--local_rank=1', '--b', '128', '-v', '/projects/neurophon/']' died with <Signals.SIGABRT: 6>.
Singularity> exit

Basically, the error is on line 212.

What am I doing wrong?

Thanks!

Here I provide the complete code:

import argparse
import sys
import os

import torch
import torch.optim as optim

import torch.distributed.autograd as dist_autograd
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

from time import time

sys.path.append('SimCLR/NVIDIA DALI')
import NVIDIA_DALI_Pipelines as NDP
sys.path.append('SimCLR/ResNet')
import ResNet as rn
sys.path.append('SimCLR/MLP')
import multilayerPerceptron as mlp
sys.path.append('SimCLR')
import SimCLR

def parse():
        parser = argparse.ArgumentParser(prog='Contrastive_Learning',
                                         description='This program executes the Contrastive Learning Algorithm using foveated saccades')
        parser.add_argument('data', metavar='DIR', default='/projects/neurophon', type=str,
                            help='path to the MSCOCO dataset')
        parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                            help='number of data loading workers (default: 4)')
        parser.add_argument('--epochs', default=90, type=int, metavar='N',
                            help='number of total epochs to run')
        parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                            help='manual epoch number (useful on restarts)')
        parser.add_argument('-b', '--batch-size', default=256, type=int,
                            metavar='N', help='mini-batch size per process (default: 256)')
        parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                            metavar='LR', help='Initial learning rate.  Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
        parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                            help='momentum')
        parser.add_argument('--temperature', default=0.05, type=float, metavar='T',
                            help='SimCLR temperature')
        parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                            metavar='W', help='weight decay (default: 1e-4)')
        parser.add_argument('--print-freq', '-p', default=10, type=int,
                            metavar='N', help='print frequency (default: 10)')
        parser.add_argument('--resume', default='', type=str, metavar='PATH',
                            help='path to latest checkpoint (default: none)')
        parser.add_argument('--dali_cpu', action='store_true',
                            help='Runs CPU based version of DALI pipeline.')
        parser.add_argument('--prof', default=-1, type=int,
                            help='Only run 10 iterations for profiling.')
        parser.add_argument('--deterministic', action='store_true')
        parser.add_argument("--local_rank", default=0, type=int)
        parser.add_argument('-t', '--test', action='store_true',
                            help='Launch test mode with preset arguments')
        parser.add_argument('-v', '--verbose', action='store_true',
                            help='provides additional details as to what the program is doing')
        args = parser.parse_args()
        return args


def main():
        global args
        args = parse()


        if not len(args.data):
                raise Exception("error: No data set provided")

        args.distributed = False
        if 'WORLD_SIZE' in os.environ:
                args.distributed = int(os.environ['WORLD_SIZE']) > 1
        

        args.gpu = 0
        args.world_size = 1
        
        if args.distributed:
                args.gpu = args.local_rank
                torch.cuda.set_device(args.gpu)
                torch.distributed.init_process_group(backend='nccl', init_method='env://')
                args.world_size = torch.distributed.get_world_size()
                if args.verbose:
                        print('distributed is True, then rank number {} is mapped in device number {}' .format(args.local_rank, args.gpu))

        args.total_batch_size = args.world_size * args.batch_size


        # Set the device
        device = torch.device('cpu' if args.dali_cpu else 'cuda:' + str(args.gpu))
        if args.verbose:
               print('Using device {} in rank number {}' .format(device, args.local_rank))


        # Set fuction_f
        function_f = rn.ResNet.ResNet18()
        function_f.to(device)
        if args.verbose:
               print('function_f created from rank {}' .format(args.local_rank))
        

        # Set function_g
        function_g = mlp.MLP(512*4*4, 1024, 128)
        function_g.to(device)
        if args.verbose:
               print('function_g created from rank {}' .format(args.local_rank))
        

        # Set SimCLR model
        img_size = (30,30)
        model = SimCLR.SimCLR_Module(args.temperature, function_f, function_g, args.batch_size, img_size, device)
        model.to(device)
        if args.verbose:
               print('SimCLR_Module created from rank {}' .format(args.local_rank))
        
        
        # # Set optimizer
        # args.lr = args.lr*float(args.batch_size*args.world_size)/256.
        # optimizer = optim.SGD(model.parameters(), args.lr,
                              # momentum=args.momentum,
                              # weight_decay=args.weight_decay)

        # Optionally resume from a checkpoint
        if args.resume:
             # Use a local scope to avoid dangling references
             def resume():
                if os.path.isfile(args.resume):
                    print("=> loading checkpoint '{}'" .format(args.resume))
                    checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
                    args.start_epoch = checkpoint['epoch']
                    model.load_state_dict(checkpoint['state_dict'])
                    optimizer.load_state_dict(checkpoint['optimizer'])
                    print("=> loaded checkpoint '{}' (epoch {})"
                                    .format(args.resume, checkpoint['epoch']))
                else:
                    print("=> no checkpoint found at '{}'" .format(args.resume))
        
             resume()


        path = args.data
        os.environ['DALI_EXTRA_PATH']=path
        test_data_root = os.environ['DALI_EXTRA_PATH']
        file_root = os.path.join(test_data_root, 'MSCOCO', 'cocoapi', 'images', 'val2014')
        annotations_file = os.path.join(test_data_root, 'MSCOCO', 'cocoapi', 'annotations', 'instances_val2014.json')

        # This is pipe1, using this we bring image batches from MSCOCO dataset
        pipe1 = NDP.COCOReader(batch_size=args.batch_size,
                               num_threads=args.workers,
                               device_id=args.local_rank,
                               file_root=file_root,
                               annotations_file=annotations_file,
                               shard_id=args.local_rank,
                               num_shards=args.world_size,
                               dali_cpu=args.dali_cpu)

        start = time()
        pipe1.build()
        total_time = time() - start
        if args.verbose:
               print('pipe1 built by rank number {} in {} seconds' .format(args.local_rank, total_time))


        # This is pipe2, which is used to augment the batches brought by pipe1 utilizing foveated saccades
        images = NDP.ImageCollector()
        fixation = NDP.FixationCommand(args.batch_size)
        pipe2 = NDP.FoveatedRetinalProcessor(batch_size=args.batch_size,
                                             num_threads=args.workers,
                                             device_id=args.local_rank,
                                             fixation_information=fixation,
                                             images=images,
                                             dali_cpu=args.dali_cpu)
        start = time()
        pipe2.build()
        total_time = time() - start
        if args.verbose:
               print('pipe2 built by rank number {} in {} seconds' .format(args.local_rank, total_time))
        


        # For distributed training, wrap the model with torch.nn.parallel.DistributedDataParallel.
        if args.distributed:
                model = DDP(model, device_ids=[args.gpu], output_device=args.gpu)
                if args.verbose:
                        print('Since we are in a distributed setting the model is replicated here in rank {}' .format(args.local_rank))



if __name__ == '__main__':
        main()

Perhaps you are running into this issue: https://forums.developer.nvidia.com/t/xid-8-in-various-cuda-deep-learning-applications-for-nvidia-gtx-1080-ti/66433?

It seems like one of the underlying causes of that error is a CUDA kernel taking too long because the GPU is also being used by other processes (like driving the display), which starves your training process and eventually aborts it. Is that forum post applicable here?

Thank you @osalpekar, I don’t think so. I am running this on a dedicated node in a cluster. There is no display, and nobody is using that node except me.

Is your code working without DDP and without DALI?
If so, have you tried the use case without one of these packages?
Also, which PyTorch, DALI, CUDA versions and which GPU are you using?
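(For reference, a quick way to collect most of that information from inside the container is a snippet like the one below; this is just a convenience sketch and assumes both torch and nvidia.dali are importable in the same interpreter.)

# version_check.py -- sketch for gathering the requested version information
import torch
import nvidia.dali as dali

print('PyTorch :', torch.__version__)
print('CUDA (torch build):', torch.version.cuda)
print('cuDNN   :', torch.backends.cudnn.version())
print('DALI    :', dali.__version__)
print('GPUs    :', [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])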

Thank you @ptrblck

My code is working only without DDP (when I comment out the code below):

210         # For distributed training, wrap the model with torch.nn.parallel.DistributedDataParallel.
211         if args.distributed:
212                 model = DDP(model)
213                 # model = DDP(model, device_ids=[args.gpu], output_device=args.gpu)
214                 if args.verbose:
215                         print('Since we are in a distributed setting the model is replicated here in rank {}' .format(args.local_rank))

This is the output without DDP:

Singularity> python3.8 -m torch.distributed.launch --nproc_per_node=2 Contrastive_Learning.py --b 128 -v /projects/neurophon/
ERROR: ld.so: object '/soft/buildtools/trackdeps/${LIB}/trackdeps.so' from LD_PRELOAD cannot be preloaded: ignored.
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
ERROR: ld.so: object '/soft/buildtools/trackdeps/${LIB}/trackdeps.so' from LD_PRELOAD cannot be preloaded: ignored.
ERROR: ld.so: object '/soft/buildtools/trackdeps/${LIB}/trackdeps.so' from LD_PRELOAD cannot be preloaded: ignored.
distributed is True, then rank number 1 is mapped in device number 1
Using device cuda:1 in rank number 1
distributed is True, then rank number 0 is mapped in device number 0
Using device cuda:0 in rank number 0
function_f created from rank 0
function_g created from rank 0
SimCLR_Module created from rank 0
function_f created from rank 1
function_g created from rank 1
SimCLR_Module created from rank 1
pipe1 built by rank number 0 in 3.062119245529175 seconds
Initialating fixation

pipe2 built by rank number 0 in 0.0029039382934570312 seconds
pipe1 built by rank number 1 in 2.993931531906128 seconds
Initialating fixation

pipe2 built by rank number 1 in 0.003136157989501953 seconds
Singularity> exit

This is the output when I comment out everything related to DALI and uncomment the DDP section:

Singularity> python3.8 -m torch.distributed.launch --nproc_per_node=2 Contrastive_Learning.py --b 128 -v /projects/neurophon/
ERROR: ld.so: object '/soft/buildtools/trackdeps/${LIB}/trackdeps.so' from LD_PRELOAD cannot be preloaded: ignored.
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
ERROR: ld.so: object '/soft/buildtools/trackdeps/${LIB}/trackdeps.so' from LD_PRELOAD cannot be preloaded: ignored.
ERROR: ld.so: object '/soft/buildtools/trackdeps/${LIB}/trackdeps.so' from LD_PRELOAD cannot be preloaded: ignored.
distributed is True, then rank number 1 is mapped in device number 1
Using device cuda:1 in rank number 1
distributed is True, then rank number 0 is mapped in device number 0
Using device cuda:0 in rank number 0
function_f created from rank 0
function_f created from rank 1
function_g created from rank 0
SimCLR_Module created from rank 0
function_g created from rank 1
SimCLR_Module created from rank 1
Traceback (most recent call last):
  File "Contrastive_Learning.py", line 295, in <module>
    main()
  File "Contrastive_Learning.py", line 213, in main
    model = DDP(model, device_ids=[args.gpu], output_device=args.gpu)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 335, in __init__
    self._ddp_init_helper()
  File "/usr/local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 441, in _ddp_init_helper
    self.reducer = dist.Reducer(
RuntimeError: CUDA error: the launch timed out and was terminated
NCCL error in: /pytorch/torch/lib/c10d/../c10d/NCCLUtils.hpp:69, unhandled cuda error, NCCL version 2.4.8
Traceback (most recent call last):
  File "Contrastive_Learning.py", line 295, in <module>
    main()
  File "Contrastive_Learning.py", line 213, in main
    model = DDP(model, device_ids=[args.gpu], output_device=args.gpu)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 335, in __init__
    self._ddp_init_helper()
  File "/usr/local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 441, in _ddp_init_helper
    self.reducer = dist.Reducer(
RuntimeError: CUDA error: the launch timed out and was terminated
NCCL error in: /pytorch/torch/lib/c10d/../c10d/NCCLUtils.hpp:69, unhandled cuda error, NCCL version 2.4.8
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.8/site-packages/torch/distributed/launch.py", line 261, in <module>
    main()
  File "/usr/local/lib/python3.8/site-packages/torch/distributed/launch.py", line 256, in main
    raise subprocess.CalledProcessError(returncode=process.returncode,
subprocess.CalledProcessError: Command '['/usr/local/bin/python3.8', '-u', 'Contrastive_Learning.py', '--local_rank=1', '--b', '128', '-v', '/projects/neurophon/']' died with <Signals.SIGABRT: 6>.
Singularity> exit

Now let’s go over the versions.

In the container I install the following version of DALI:
# nvidia dali
pip3.8 --no-cache-dir --disable-pip-version-check install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100

Also:

Singularity> cat /etc/centos-release
ERROR: ld.so: object '/soft/buildtools/trackdeps/${LIB}/trackdeps.so' from LD_PRELOAD cannot be preloaded: ignored.
CentOS Linux release 7.8.2003 (Core)


Singularity> python3.8
ERROR: ld.so: object '/soft/buildtools/trackdeps/${LIB}/trackdeps.so' from LD_PRELOAD cannot be preloaded: ignored.
Python 3.8.3 (default, Oct 16 2020, 20:24:57) 
[GCC 4.8.5 20150623 (Red Hat 4.8.5-39)] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> print(torch.__version__)
1.6.0
>>> 


Singularity> nvidia-smi 
ERROR: ld.so: object '/soft/buildtools/trackdeps/${LIB}/trackdeps.so' from LD_PRELOAD cannot be preloaded: ignored.
Tue Oct 27 12:46:37 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64.00    Driver Version: 440.64.00    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla K80           On   | 00000000:07:00.0 Off |                    0 |
| N/A   27C    P8    33W / 149W |     70MiB / 11441MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla K80           On   | 00000000:08:00.0 Off |                    0 |
| N/A   21C    P8    29W / 149W |     69MiB / 11441MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      1181      G   /usr/bin/Xorg                                 68MiB |
|    1      1181      G   /usr/bin/Xorg                                 68MiB |
+-----------------------------------------------------------------------------+
Singularity> 

@dariodematties To simplify the investigation here, can you check if simple torch CUDA operations work fine (ex: creating tensors on GPUs and performing some basic ops on them)? In addition to this, testing NCCL collective operations like allreduce independently would be a good idea to see if there is something related to NCCL here.
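(For anyone following along, a minimal sanity check along those lines might look like the sketch below. The file name sanity_check.py is just a placeholder, and the script is assumed to be launched the same way as the training script, e.g. python3.8 -m torch.distributed.launch --nproc_per_node=2 sanity_check.py.)

# sanity_check.py -- minimal per-rank CUDA + NCCL allreduce test (sketch)
import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', default=0, type=int)
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')

# 1. Basic CUDA ops on this rank's GPU
x = torch.randn(1000, 1000, device='cuda')
y = x @ x.t()
torch.cuda.synchronize()
print('rank {}: basic CUDA ops OK'.format(dist.get_rank()))

# 2. NCCL allreduce across ranks
t = torch.ones(1, device='cuda') * dist.get_rank()
dist.all_reduce(t, op=dist.ReduceOp.SUM)
torch.cuda.synchronize()
print('rank {}: allreduce result = {}'.format(dist.get_rank(), t.item()))

If this also dies with the same launch-timeout error, the problem is independent of DALI and DDP.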

model = DDP(model)

In this case, I’m assuming the model is actually on CPU and that’s why it’s working?
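(A quick way to verify where the parameters actually live is something like the fragment below, placed inside main() right after model.to(device); it is just a sanity-check sketch against the script above.)

# Sketch: report which device(s) the model parameters are on before wrapping in DDP
param_devices = {p.device for p in model.parameters()}
print('rank {}: model parameters live on {}'.format(args.local_rank, param_devices))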

Thank you very much @pritamdamania87

I tested it by changing the backend in torch.distributed.init_process_group(backend='nccl', init_method='env://')
from nccl to gloo, and it worked.

As far as I can see here, there are several communication operations in gloo which are not available on GPU.
With gloo only broadcast and all_reduce are available on GPU tensors. I do not know whether this will affect the performance of the code with respect to DDP.
From the DDP documentation I understand that the Reducer uses the allreduce operation, and the gradients are then averaged across the replicated models in each process.
Yet, I am not sure how using gloo instead of nccl may affect performance.
Thanks!
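(To make the nccl/gloo comparison easy to reproduce, here is a minimal self-contained sketch of the DDP wrap with a switchable backend; the file name and the --dist-backend flag are mine and not part of the original script.)

# backend_switch_sketch.py -- DDP wrap with a configurable backend (sketch)
# Launch with e.g.: python3.8 -m torch.distributed.launch --nproc_per_node=2 backend_switch_sketch.py --dist-backend gloo
import argparse
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--dist-backend', default='nccl', choices=['nccl', 'gloo'],
                    help='gloo only supports broadcast/all_reduce on GPU tensors')
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend=args.dist_backend, init_method='env://')

model = torch.nn.Linear(10, 10).cuda(args.local_rank)
model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
print('rank {}: DDP wrap succeeded with backend {}'.format(dist.get_rank(), args.dist_backend))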

Yet, I am not sure how using gloo instead of nccl may affect performance.

NCCL is much more performant than GLOO when dealing with GPU tensors (we’ve seen a difference of 1.5x - 1.8x in some cases). If performance is important for you, you should try to stick with NCCL and get it to work.

Thank you @pritamdamania87!

At least I know where the problem is now and can start running some tests before implementing larger models.

I will certainly ask the cluster support team, since this seems to be a deeper issue with the packages.

@dariodematties So if we use DALI with distributed computing, we should use gloo? Any update with nccl?

I do not think so @Johnson_Mark. I guess it depends on your container configuration (if you are using containerization) and on the NVIDIA libraries installed on the machine, but I am not sure.
The truth is that I have used nccl with DALI on another machine with another container.