Error in DistributedDataParallel

Hi, I am trying to use DistributedDataParallel for my job. I wrote the two scripts below, but neither of them works properly. It would be very kind if someone could help me find the problem.

class Model(nn.Module):
    # Our model

    def __init__(self):
        super(Model, self).__init__()
        
        self.fc1 = nn.Conv2d(1,10,3)
        self.bn1 = nn.BatchNorm2d(10)
        self.fc2= nn.Conv2d(10,20,3)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc3= nn.Linear(11520,10)
        
    def forward(self,x):
        
        print(f'inout_size: {x.size()}')
        
        x = F.relu(self.fc1(x))
        
        x = self.bn1(x)
        
        x = F.relu(self.fc2(x))
        
        x = self.bn2(x)
        
        x = x.view(x.size(0),-1)
        
        x = self.fc3(x)
        print(f'output_size: {x.size()}')
        return(x)
########################################    



def train(args):
    
    ########################################
    rank =args.gpui
    
    dist.init_process_group(backed = 'nccl',
                           init_method = 'env://',
                           world_size= args.world_size,
                           rank=rank)
    
    torch.manual_seed(0)
    
    model = Model()
    
    torch.cuda.set_device(args.gpui)
    model= model.to(device)
    optimizer = optim.Adam(model.parameters(),lr=0.1)
    lr_sch = lr_scheduler.StepLR(optimizer,step_size=2,gamma=0.1)
    criterion = nn.CrossEntropyLoss().to(device)
    
    ######################################
    model = nn.DistributedDataParallel(model, device_ids = [args.gpui])
    #####################################

    
    mnist =torchvision.datasets.MNIST('./data',train= True,download=True,
                                      transform =transforms.ToTensor())
    
    ####################################
    train_sampler = torch.utils.data.distributed.DistributedSampler(mnist,
                                                                    num_replicas=args.world_size,
                                                                    rank = rank)
    
    ###################################
    
    dataloader = DataLoader(mnist,batch_size=32,num_workers =4,pin_memory=True,
                                               sampler = train_sampler)
    
    #####################################
    
    
    for epoch in range(num_epochs):

        total_loss =0
        
        for X,y in dataloader:   
        
            X= X.to(device)
            y = y.long().to(device)
            pred = model(X)
            loss = criterion(pred,y)
            t_loss+= loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f'Loss: {t_loss/len(dataloader)}')
if __name__=='__main__':
    
    parser = argparse.ArgumentParser()
    
    parser.add_argument('-n', '--nodes', default=1,
                        type=int, metavar='N')
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    
    parser.add_argument('-gi', '--gpui', default=3, type=int,
                        help='the index of gpu')
    
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking within the nodes')
    
    parser.add_argument('--epochs', default=2, type=int, 
                        metavar='N',
                        help='number of total epochs to run')
    args = parser.parse_args()
    #########################################################
    args.world_size = args.gpus * args.nodes                 # total number of processes: one process per GPU across all nodes
    os.environ['MASTER_ADDR'] = '172.20.24.55'               # IP address of the machine hosting the rank-0 process
    os.environ['MASTER_PORT'] = '8890'                       # a free port on that machine
    mp.spawn(train,args=(args,),nprocs=args.world_size)

I got the following error,

--> 125     mp.spawn(train,args=(args,),nprocs=args.world_size)         #
    126     #########################################################
    127 

~/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py in spawn(fn, args, nprocs, join, daemon, start_method)
    198                ' torch.multiprocessing.start_process(...)' % start_method)
    199         warnings.warn(msg)
--> 200     return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')

~/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py in start_processes(fn, args, nprocs, join, daemon, start_method)
    147             daemon=daemon,
    148         )
--> 149         process.start()
    150         error_queues.append(error_queue)
    151         processes.append(process)

~/anaconda3/lib/python3.7/multiprocessing/process.py in start(self)
    110                'daemonic processes are not allowed to have children'
    111         _cleanup()
--> 112         self._popen = self._Popen(self)
    113         self._sentinel = self._popen.sentinel
    114         # Avoid a refcycle if the target function holds an indirect

~/anaconda3/lib/python3.7/multiprocessing/context.py in _Popen(process_obj)
    282         def _Popen(process_obj):
    283             from .popen_spawn_posix import Popen
--> 284             return Popen(process_obj)
    285 
    286     class ForkServerProcess(process.BaseProcess):

~/anaconda3/lib/python3.7/multiprocessing/popen_spawn_posix.py in __init__(self, process_obj)
     30     def __init__(self, process_obj):
     31         self._fds = []
---> 32         super().__init__(process_obj)
     33 
     34     def duplicate_for_child(self, fd):

~/anaconda3/lib/python3.7/multiprocessing/popen_fork.py in __init__(self, process_obj)
     18         self.returncode = None
     19         self.finalizer = None
---> 20         self._launch(process_obj)
     21 
     22     def duplicate_for_child(self, fd):

~/anaconda3/lib/python3.7/multiprocessing/popen_spawn_posix.py in _launch(self, process_obj)
     40         tracker_fd = semaphore_tracker.getfd()
     41         self._fds.append(tracker_fd)
---> 42         prep_data = spawn.get_preparation_data(process_obj._name)
     43         fp = io.BytesIO()
     44         set_spawning_popen(self)

~/anaconda3/lib/python3.7/multiprocessing/spawn.py in get_preparation_data(name)
    170     # or through direct execution (or to leave it alone entirely)
    171     main_module = sys.modules['__main__']
--> 172     main_mod_name = getattr(main_module.__spec__, "name", None)
    173     if main_mod_name is not None:
    174         d['init_main_from_name'] = main_mod_name

AttributeError: module '__main__' has no attribute '__spec__'

One error I noticed is that, when using spawn, it will pass the rank as the first argument to the target function, followed by the args you provided. So the signature of the train function should be train(rank, args).
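For illustration, a minimal sketch of that call pattern (simplified, not the full training code from above):

import torch.multiprocessing as mp

def train(rank, args):
    # mp.spawn launches nprocs processes and calls train(i, *args) in each,
    # where i is the process index in [0, nprocs).
    print(f'process {rank} got args: {args}')

if __name__ == '__main__':
    mp.spawn(train, args=('hello',), nprocs=4)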

But the above does not seem to be the cause of the logged error. That error does not look PyTorch-related; see this discussion: https://stackoverflow.com/questions/45720153/python-multiprocessing-error-attributeerror-module-main-has-no-attribute
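For reference, a workaround commonly suggested in that thread (an assumption on my part, not an official PyTorch fix) for the case where the code is launched from an interactive session such as Jupyter is to give __main__ a __spec__ attribute before calling mp.spawn, or simply to run the code as a standalone .py script:

# Workaround sketch (assumption): only relevant when running from an interactive
# session where the __main__ module has no __spec__ attribute.
import sys

main_module = sys.modules['__main__']
if not hasattr(main_module, '__spec__'):
    main_module.__spec__ = None   # lets multiprocessing's get_preparation_data() proceed

# The more robust option is to save the code as e.g. train_ddp.py (hypothetical name)
# and launch it from a shell with `python train_ddp.py` instead of a notebook cell.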

Thank you for your answer. What about the following code? I tried a different approach and got the error below.


class Model(nn.Module):
    # Our model
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Conv2d(1,10,3)
        self.bn1 = nn.BatchNorm2d(10)
        self.fc2= nn.Conv2d(10,20,3)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc3= nn.Linear(11520,10)
        
    def forward(self,x):
        print(f'inout_size: {x.size()}')
        x = F.relu(self.fc1(x))
        x = self.bn1(x)
        x = F.relu(self.fc2(x))
        x = self.bn2(x)
        x = x.view(x.size(0),-1)
        x = self.fc3(x)
        print(f'output_size: {x.size()}')
        return(x)
########################################    



def train(gpu):
    rank = gpu
    
    dist.init_process_group(backed = 'nccl',
                           init_method = 'env://',
                           world_size= 4,
                           rank=rank)
    
    torch.manual_seed(0)
    
    model = Model()
    
    torch.cuda.set_device(gpu)
    model= model.to(device)
    optimizer = optim.Adam(model.parameters(),lr=0.1)
    lr_sch = lr_scheduler.StepLR(optimizer,step_size=2,gamma=0.1)
    criterion = nn.CrossEntropyLoss().to(device)
    
    ######################################
    model = nn.DistributedDataParallel(model, device_ids = [gpu])
    #####################################

    
    mnist =torchvision.datasets.MNIST('./data',train= True,download=True,
                                      transform =transforms.ToTensor())
    ####################################
    train_sampler = torch.utils.data.distributed.DistributedSampler(mnist,
                                                                    num_replicas=4,
                                                                    rank = rank)
    
    ###################################
    
    dataloader = DataLoader(mnist,batch_size=32,num_workers =4,pin_memory=True,
                                               sampler = train_sampler)
    
    #####################################
    
    
    for epoch in range(10):
        total_loss =0
        for X,y in dataloader:   
            X= X.to(device)
            y = y.long().to(device)
            pred = model(X)
            loss = criterion(pred,y)
            t_loss+= loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()  
        print(f'Loss: {t_loss/len(dataloader)}')
def main():
    os.environ['MASTER_ADDR'] = '172.20.24.55' ### the IP of vm_gpu02
    os.environ['MASTER_PORT'] = '9000'
    mp.spawn(train,nprocs=4)

if __name__=='__main__':
      main()

process 3 terminated with exit code 1

Which line threw the error? Could you please paste the trace as well?

I might be missing something, but it looks like the device variable is undefined? Did you mean gpu instead?

Oh, yes. Sorry, I was running several scripts at the same time today; that’s why I didn’t notice this mistake.

Hey, I corrected that mistake, but it wasn’t the problem. You can see the traceback below. I actually want to run the model on 4 GPUs, and I set nprocs=4, but I don’t know whether I need to add anything else to the code. At the moment it only uses one GPU.

class Model(nn.Module):
    # Our model
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Conv2d(1,10,3)
        self.bn1 = nn.BatchNorm2d(10)
        self.fc2= nn.Conv2d(10,20,3)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc3= nn.Linear(11520,10) 
    def forward(self,x):
        print(f'inout_size: {x.size()}')
        x = F.relu(self.fc1(x))
        x = self.bn1(x)
        x = F.relu(self.fc2(x))
        x = self.bn2(x)
        x = x.view(x.size(0),-1)
        x = self.fc3(x)
        print(f'output_size: {x.size()}')
        return(x)
########################################    
def train(gpu):
    rank = gpu 
    dist.init_process_group(backed = 'nccl',
                           init_method = 'env://',
                           world_size= 4,
                           rank=rank)
    torch.manual_seed(0)
    model = Model()  
    torch.cuda.set_device(gpu)
    model= model.to(gpu)
    optimizer = optim.Adam(model.parameters(),lr=0.1)
    lr_sch = lr_scheduler.StepLR(optimizer,step_size=2,gamma=0.1)
    criterion = nn.CrossEntropyLoss().to(gpu)
    ######################################
    model = nn.DistributedDataParallel(model, device_ids = [gpu])
    #####################################
    mnist =torchvision.datasets.MNIST('./data',train= True,download=True,
                                      transform =transforms.ToTensor())
    ####################################
    train_sampler = torch.utils.data.distributed.DistributedSampler(mnist,
                                                                    num_replicas=4,
                                                                    rank = rank)
    ###################################
    dataloader = DataLoader(mnist,batch_size=32,num_workers =4,pin_memory=True,
                                               sampler = train_sampler)
    #####################################
    for epoch in range(10):
        total_loss =0
        for X,y in dataloader:   
            X= X.to(gpu)
            y = y.long().to(gpu)
            pred = model(X)
            loss = criterion(pred,y)
            t_loss+= loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f'Loss: {t_loss/len(dataloader)}')
def main():
    os.environ['MASTER_ADDR'] = '172.20.24.55' ### the IP of vm_gpu02
    os.environ['MASTER_PORT'] = '9000'
    mp.spawn(train,nprocs=4)
if __name__=='__main__': 
    main()

Exception                                 Traceback (most recent call last)
<ipython-input-10-e18ebd33df91> in <module>
      1 if __name__=='__main__':
      2 
----> 3     main()

<ipython-input-9-331de420a7b8> in main()
      5     os.environ['MASTER_PORT'] = '9000'
      6 
----> 7     mp.spawn(train,nprocs=4)

~/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py in spawn(fn, args, nprocs, join, daemon, start_method)
    198                ' torch.multiprocessing.start_process(...)' % start_method)
    199         warnings.warn(msg)
--> 200     return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')

~/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py in start_processes(fn, args, nprocs, join, daemon, start_method)
    156 
    157     # Loop on join until it returns True or raises an exception.
--> 158     while not context.join():
    159         pass
    160 

~/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py in join(self, timeout)
    111                 raise Exception(
    112                     "process %d terminated with exit code %d" %
--> 113                     (error_index, exitcode)
    114                 )
    115 

Exception: process 2 terminated with exit code 1

Found a few errors when debugging this locally:

  1. in init_process_group, the arg name is backend instead of backed
  2. DDP is from torch.nn.parallel package instead of torch.nn
  3. t_loss is used before definition.

The following code works for me. I tried it on 2 GPUs, as I only have 2 in my dev env. Some general suggestions for debugging: 1) it helps if you can locate which line threw the error, and 2) it is easier to debug if you start from a simpler version and gradually add complexity to the code.

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

class Model(nn.Module):
    # Our model
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Conv2d(1,10,3)
        self.bn1 = nn.BatchNorm2d(10)
        self.fc2= nn.Conv2d(10,20,3)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc3= nn.Linear(11520,10)
    def forward(self,x):
        print(f'inout_size: {x.size()}')
        x = F.relu(self.fc1(x))
        x = self.bn1(x)
        x = F.relu(self.fc2(x))
        x = self.bn2(x)
        x = x.view(x.size(0),-1)
        x = self.fc3(x)
        print(f'output_size: {x.size()}')
        return(x)
########################################
def train(gpu):
    print("1111")
    rank = gpu
    dist.init_process_group(backend = 'nccl',
                           init_method = 'env://',
                           world_size= 2,
                           rank=rank)
    print("2222")
    torch.manual_seed(0)
    model = Model()
    torch.cuda.set_device(gpu)
    model= model.to(gpu)
    optimizer = optim.Adam(model.parameters(),lr=0.1)
    lr_sch = lr_scheduler.StepLR(optimizer,step_size=2,gamma=0.1)
    criterion = nn.CrossEntropyLoss().to(gpu)
    ######################################
    model = nn.parallel.DistributedDataParallel(model, device_ids = [gpu])
    #####################################
    mnist =torchvision.datasets.MNIST('./data',train= True,download=True,
                                      transform =transforms.ToTensor())
    ####################################
    train_sampler = torch.utils.data.distributed.DistributedSampler(mnist,
                                                                    num_replicas=4,
                                                                    rank = rank)
    ###################################
    dataloader = DataLoader(mnist,batch_size=32,num_workers =4,pin_memory=True,
                                               sampler = train_sampler)
    #####################################
    t_loss = None
    for epoch in range(2):
        total_loss =0
        for X,y in dataloader:
            X= X.to(gpu)
            y = y.long().to(gpu)
            pred = model(X)
            loss = criterion(pred,y)
            t_loss= loss.item() if t_loss is None else t_loss + loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f'Loss: {t_loss/len(dataloader)}')

def main():
    os.environ['MASTER_ADDR'] = 'localhost' ### the IP of vm_gpu02
    os.environ['MASTER_PORT'] = '9000'
    mp.spawn(train,nprocs=2)
if __name__=='__main__':
    main()

Hi, thank you for your answer. Regarding your code: when we run it, don’t we need to give the GPU index? How does the code know which GPUs the processes should run on?

And my second question is about num_replicas: shouldn’t it be equal to the number of processes?

A process can access any visible GPU. The one-process-per-GPU requirement comes from DDP, to avoid NCCL communicator hangs. Ideally, we should set CUDA_VISIBLE_DEVICES for each process accordingly, so that each process only sees one GPU and cuda:0 in each process points to a different physical GPU. But if you are confident that no code will accidentally access a different GPU, calling .to(gpu) directly is sufficient. Here we are using the process index provided by mp.spawn as the GPU id.
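To make that concrete, here is a minimal sketch of the two options (assumptions: a single node, one process per GPU, and the process index from mp.spawn used directly as the GPU index):

import os
import torch
import torch.multiprocessing as mp

def worker(rank):
    # Option 1: restrict visibility so this process only sees one physical GPU.
    # This must run before any CUDA call in the process, which is why it sits at the
    # top of the spawned function; afterwards 'cuda:0' maps to physical GPU `rank`.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
    device = torch.device('cuda:0')

    # Option 2 (what the code above does): keep all GPUs visible and just pick one,
    # e.g. torch.cuda.set_device(rank) followed by .to(rank).

    x = torch.ones(2, 2, device=device)
    print(f'process {rank} is using {x.device}')

if __name__ == '__main__':
    mp.spawn(worker, nprocs=2)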

And my second question is about num_replicas: shouldn’t it be equal to the number of processes?

Yep, you are right.
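For completeness, a small sketch (assuming dist.init_process_group has already been called in the process, as in train above, and mnist is the dataset from that code): num_replicas should match the world size, and if it is omitted DistributedSampler infers it, together with the rank, from the default process group.

from torch.utils.data.distributed import DistributedSampler

# num_replicas and rank default to the process group's world size and rank, so this
# is equivalent to DistributedSampler(mnist, num_replicas=world_size, rank=rank).
train_sampler = DistributedSampler(mnist)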