Error in argparse

Hi, I am trying to use DistributedDataParallel in my work to speed up training. But with the following code I get an argparse error whose cause I can’t understand. This is the first time I have worked with argparse, so any help is really appreciated.

class Model(nn.Module):
    """Small CNN: two conv + batch-norm stages followed by a linear classifier.

    The linear layer expects 11520 flattened features per sample and
    produces 10 outputs.
    """

    def __init__(self):
        super().__init__()

        # NOTE: fc1/fc2 are convolutions despite the "fc" prefix; the names
        # are kept because they are the keys used in the module's state dict.
        self.fc1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(num_features=10)
        self.fc2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(num_features=20)
        self.fc3 = nn.Linear(in_features=11520, out_features=10)

    def forward(self, x):
        """Return the 10 output scores for each sample in the batch ``x``."""
        # Stage 1: conv -> ReLU -> batch norm.
        stage1 = self.bn1(F.relu(self.fc1(x)))
        # Stage 2: conv -> ReLU -> batch norm.
        stage2 = self.bn2(F.relu(self.fc2(stage1)))
        # Flatten everything except the batch dimension for the classifier.
        flat = stage2.view(stage2.size(0), -1)
        return self.fc3(flat)

def train(gpu, args):
    """Train ``Model`` on the MNIST training set on one CUDA device.

    Args:
        gpu: Index of the CUDA device to use.  Progress and total training
            time are printed only when this is 0.
        args: Parsed options; only ``args.epochs`` (the number of passes
            over the training set) is read here.
    """
    torch.manual_seed(0)
    model = Model()
    torch.cuda.set_device(gpu)
    model.to(gpu)
    batch_size = 100
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(gpu)
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)
    # Data loading code
    train_dataset = torchvision.datasets.MNIST(root='./data',
                                               train=True,
                                               transform=transforms.ToTensor(),
                                               download=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=0,
                                               pin_memory=True)

    start = datetime.now()
    total_step = len(train_loader)
    for epoch in range(args.epochs):
        for i, (images, labels) in enumerate(train_loader):
            # Move the batch to the model's device.  Calling .to() with only
            # non_blocking=True names no target device, so the tensors would
            # stay on the CPU and the forward pass would fail on a device
            # mismatch.
            images = images.to(gpu, non_blocking=True)
            labels = labels.to(gpu, non_blocking=True)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Report every 100 steps, from device 0 only.
            if (i + 1) % 100 == 0 and gpu == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1,
                    args.epochs,
                    i + 1,
                    total_step,
                    loss.item())
                   )
    if gpu == 0:
        print("Training complete in: " + str(datetime.now() - start))
def main(argv=None):
    """Parse the command-line options and run training on GPU 0.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``.  Pass ``[]`` when calling from a notebook or
            any other host that puts its own arguments in ``sys.argv``;
            otherwise argparse reports them as "unrecognized arguments"
            and exits with status 2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N')
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking within the nodes')
    # train() reads args.epochs, so the option must exist on the namespace.
    parser.add_argument('--epochs', default=2, type=int, metavar='N',
                        help='number of total epochs to run')

    # With argv=None argparse falls back to sys.argv[1:].
    args = parser.parse_args(argv)

    train(0, args)

# Start training only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()

The error is:

SystemExit                                Traceback (most recent call last)
<ipython-input-32-c7bc734e5e35> in <module>
      1 if __name__ == '__main__':
----> 2     main()

<ipython-input-31-61eca122604a> in main()
     10 
     11 
---> 12     args = parser.parse_args()
     13 
     14     train(0, args)

~/anaconda3/lib/python3.7/argparse.py in parse_args(self, args, namespace)
   1750         if argv:
   1751             msg = _('unrecognized arguments: %s')
-> 1752             self.error(msg % ' '.join(argv))
   1753         return args
   1754 

~/anaconda3/lib/python3.7/argparse.py in error(self, message)
   2499         self.print_usage(_sys.stderr)
   2500         args = {'prog': self.prog, 'message': message}
-> 2501         self.exit(2, _('%(prog)s: error: %(message)s\n') % args)

~/anaconda3/lib/python3.7/argparse.py in exit(self, status, message)
   2486         if message:
   2487             self._print_message(message, _sys.stderr)
-> 2488         _sys.exit(status)
   2489 
   2490     def error(self, message):

SystemExit: 2

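Try building the parser outside the main function and passing the parsed arguments in as a parameter. Note that the traceback comes from IPython/Jupyter, where `sys.argv` holds the kernel's own command-line arguments, so `parse_args()` rejects them as unrecognized; in a notebook, call `parser.parse_args([])` instead: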

def main(args):
    """Run training on GPU 0 using the already-parsed options in ``args``."""
    train(gpu=0, args=args)

# Script entry point: build the option parser, parse the command line and
# hand the resulting namespace to main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N')
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking within the nodes')
    # train() reads args.epochs, so the option has to be defined here.
    parser.add_argument('--epochs', default=2, type=int, metavar='N',
                        help='number of total epochs to run')

    # parse_args() with no argument parses sys.argv[1:].  In a notebook,
    # sys.argv holds the kernel's own arguments, so use
    # parser.parse_args([]) there to avoid "unrecognized arguments".
    args = parser.parse_args()
    main(args)

Aha, Ok. Thank you so much.