NaN loss with DataParallel

I’m a bit stumped here. I have a semantic segmentation model that trains fine on a single GPU, but when I try DataParallel with two GPUs my loss increases until I get NaN. Are there any common pitfalls I should watch out for with DataParallel and semantic segmentation?

Could you post a small code snippet reproducing this error?
It’s strange that your loss suddenly increases with DataParallel when it was decreasing on a single GPU.
Did you change any other parts of your code, e.g. unintentionally removing optimizer.zero_grad() or moving it out of the training loop?

I’ll try to figure out how to reproduce it with a smaller example; right now the code is pretty big.

However, the only differences between my multi-GPU and single-GPU runs are:

if args.gpu == 99:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu")

and this:

if args.gpu == 99:
    net = nn.DataParallel(net)
net.to(device)

Basically, if I set the arg to 99, it’ll use DataParallel. Not sure what I’m missing.

While you work on a code snippet, we could continue the discussion and try to figure out what’s going on.

  • If you use DataParallel with 2 GPUs, do you also double the batch size?
  • What batch size are you using? (Single and multiple GPU)
  • Do you have any BatchNorm layers?
  • Does your loss increase immediately after the first iteration or does it decrease for a period of time?

  • Yup, I go from a batch size of 12 to 24.
  • There are BatchNorm layers. I’m using a pretrained SEResNeXt50 from https://github.com/Cadene/pretrained-models.pytorch as the backbone to a UNet.
  • Loss increases immediately after the first batch and gets to NaN in 2-3 batches.
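
For context on the batch-size and BatchNorm questions, here is a rough sketch of how a batch ends up divided among replicas. It assumes DataParallel splits the input along the batch dimension into near-equal chunks, the way torch.chunk does; per_replica_sizes is a hypothetical helper written for illustration, not a PyTorch function.

```python
import math


def per_replica_sizes(batch_size, num_replicas):
    """Approximate the chunk sizes each replica receives (assumed torch.chunk-style split)."""
    # Each chunk holds ceil(batch_size / num_replicas) samples; the last may be smaller.
    chunk = math.ceil(batch_size / num_replicas)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        take = min(chunk, remaining)
        sizes.append(take)
        remaining -= take
    return sizes


# Batch sizes mentioned in the thread: 12 on one GPU, 24 on two GPUs.
single_gpu = per_replica_sizes(12, 1)
two_gpus = per_replica_sizes(24, 2)
# A lone leftover sample would reach only one replica.
leftover = per_replica_sizes(1, 2)
print(single_gpu, two_gpus, leftover)
```

Under that assumption, doubling the batch from 12 to 24 on two GPUs leaves each replica with 12 samples, so the per-replica BatchNorm statistics are computed over the same number of samples as in the single-GPU run.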

Could you try lowering the learning rate or momentum step by step to see if and when the loss starts to decrease?
If it just stalls, I guess we would have to debug your code.
However, if the loss starts to decrease, I wonder whether your issue is related to the problem of accumulating momentum across distributed workers. I’ll try to dig up the paper and see if it applies.

Oh interesting. Could segmentation somehow exacerbate that problem since we are doing per-pixel classification so to speak?

Lowering the learning rate or momentum didn’t help.

I’m going to continue training this with a single gpu for now. It should take about 5 days and then I can try some more things and pick this conversation back up.

I finally had time to create a minimal example. Both the train and validation losses go to NaN pretty quickly.

import time
import numpy as np
import argparse

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as tforms

import torchvision as vsn

parser = argparse.ArgumentParser(description='CIFAR10 Training')
parser.add_argument('--lr', default=0.001, type=float,
                    help='initial learning rate')
parser.add_argument('--epochs', default=20, type=int, 
                    help='number of epochs')
parser.add_argument('--batch_size', default=256, type=int, 
                    help='size of batches')
parser.add_argument('--gpu', default=0, type=int, 
                    help='which gpu to run')
args = parser.parse_args()

class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        self.resnet_layers = list(vsn.models.resnet18(pretrained=False).children())
        self.maxpool = self.resnet_layers[3]
        # encoding layers
        self.encode_a = nn.Sequential(*self.resnet_layers[:3])
        self.encode_b = self.resnet_layers[4]
        self.encode_c = self.resnet_layers[5]
        self.encode_d = self.resnet_layers[6]
        self.encode_e = self.resnet_layers[7]
        # avgpool
        self.avgpool = nn.AvgPool2d(kernel_size=2)
        # output layer
        self.linear = nn.Linear(512, 10)
        # dropout
        self.dropout = nn.Dropout2d(p=0.15)

    def forward(self, x):
        x = self.encode_a(x)
        x = self.encode_b(x)
        x = self.encode_c(x)
        x = self.encode_d(x)
        x = self.encode_e(x)
        x = self.avgpool(x)
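        # Flatten (N, 512, 1, 1) to (N, 512). Unlike x.squeeze(), this keeps the
        # batch dimension even if a DataParallel replica receives a single sample.
        out = self.linear(torch.flatten(x, 1))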
        
        return out


if args.gpu == 99:
    device = torch.device("cuda:0")
else:
    torch.cuda.set_device(args.gpu)
    device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')

transforms = tforms.Compose([tforms.RandomHorizontalFlip(p=0.5),
                             tforms.RandomCrop(size=32, padding=4),
                             tforms.ToTensor(),
                             tforms.Normalize((0.5, 0.5, 0.5),
                                              (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                        download=True, transform=transforms)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                       download=True, transform=tforms.Compose([tforms.ToTensor(),
                                                                                tforms.Normalize(
                                                                                    (0.5,0.5,0.5),
                                                                                    (0.5,0.5,0.5))]))
testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                         shuffle=False, num_workers=2)


print('Using {} labeled  data'.format(len(trainset)))
print('Using {} validation data'.format(len(testset)))

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

net = ResNet()
if args.gpu == 99:
    net = nn.DataParallel(net)
net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=args.lr)

def train():
    net.train(True)
    running_loss = 0.

    for i, data in enumerate(trainloader):
        imgs, labels = data

        optimizer.zero_grad()

        preds = net(imgs.to(device))
        loss = criterion(preds, labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Average over the number of batches actually seen (the last one may be partial).
    avg_loss = running_loss / len(trainloader)
    print('Avg Train Loss: {:.4}'.format(avg_loss))
    return avg_loss

def valid():
    net.eval()

    running_loss = 0.
    running_corr = 0.
    total = 0.

    with torch.no_grad():
        for data in testloader:
            imgs, labels = data

            preds = net(imgs.to(device))

            _, predicted = torch.max(preds, 1)
            total += labels.size(0)
            running_corr += (predicted == labels.to(device)).sum().item()

            running_loss += criterion(preds, labels.to(device)).item()

    # Average over the number of batches actually seen (the last one may be partial).
    avg_loss = running_loss / len(testloader)
    acc = running_corr / total
    print('Avg Valid Loss: {:.4}, Valid Acc: {:.4}'.format(avg_loss, acc))
    return avg_loss, acc

train_losses = []
valid_losses = []
valid_accs = []
try:
    for e in range(args.epochs):
        start = time.time()
        print('\n' + 'Epoch {}/{}'.format(e, args.epochs))
        train_loss = train()
        valid_loss, valid_acc = valid()

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        valid_accs.append(valid_acc)
        end = time.time() - start
        print('{} seconds'.format(end))

except KeyboardInterrupt:
    pass

import pandas as pd

out_dict = {'train_losses': train_losses, 
            'valid_losses': valid_losses,
            'valid_accs': valid_accs}

out_log = pd.DataFrame(out_dict)
print(out_log.head())

out_log.to_csv('cifar10_resnet_gpu-{}.csv'.format(args.gpu), index=False)

I have the same issue. I get NaN losses when I validate my model using dataparallel, but losses during training are calculated fine.

Did you manage to fix this problem?

Could you check whether your validation data contains invalid values (inf or NaN)?

I checked using torch.sum(torch.isnan(val_pred)), and there are no nan or inf values in my validation data (predicted or ground truth masks).

This only seems to be a multi-gpu problem. When using a single GPU this problem doesn’t happen.
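
One caveat on the check above: torch.isnan flags only NaN entries, so inf values would go unnoticed. A minimal sketch of a check that covers both, using a made-up tensor in place of real predictions (count_non_finite is a hypothetical helper, not part of PyTorch):

```python
import torch


def count_non_finite(t):
    """Count NaN and +/-inf entries in a tensor."""
    return {
        "nan": int(torch.isnan(t).sum().item()),
        "inf": int(torch.isinf(t).sum().item()),
    }


# Illustrative values only; in practice this would be val_pred or the masks.
val_pred = torch.tensor([0.5, float("nan"), float("inf"), float("-inf")])
counts = count_non_finite(val_pred)
# torch.isfinite is False for NaN, +inf and -inf alike.
all_finite = bool(torch.isfinite(val_pred).all())
print(counts, all_finite)
```

Running the same check on the inputs, the targets, and the model outputs of each validation batch might help narrow down where the first non-finite value appears.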