Nan Loss with torch.cuda.amp and CrossEntropyLoss

I am trying to train a DDP model (one GPU per process, but I've added with autocast(enabled=args.use_mp): around the model forward just in case) with mixed precision using torch.cuda.amp in the train_bert function.
The model trains fine without AMP as well as with autocast(enabled=False). When I try running it with mixed precision (args.use_mp = True), I get NaN loss after the first iteration.
I used autograd.detect_anomaly() to find that the NaN occurs in CrossEntropyLoss: RuntimeError: Function 'LogSoftmaxBackward' returned nan values in its 0th output. I am not sure what kind of mistake I am looking for.
Below is the code for the training function and for the loss class (criterion):

import time
from itertools import islice

import torch
from torch import nn
from torch.cuda.amp import autocast
from tqdm import tqdm


def train_bert(rank, args, epoch,
               model, optimizer, criterion, loss_components,
               train_generator, phase, scaler=None):
    """
    bert training on args.bert_steps batches
    :param rank: process rank
    :param args: config class
    :param epoch: epoch number for logging
    :param model: bert model
    :param optimizer: bert optimizer
    :param criterion: callable loss
    :param loss_components: torch.Tensor with cumulative losses for each objective
    :param train_generator: train data loader (chunk wise)
    :param phase: phase for phase wise training
    :param scaler: scaler for mixed precision training (used if args.use_mp=True)
    :return: mean step loss on the part of the training set
    """
    model.train()
    optimizer.zero_grad()
    nb_tr_steps = phase * args.bert_steps
    time_dict = {'step_time': 0.0, 'forward_time': 0.0, 'backward_time': 0.0}

    for step, batch in islice(enumerate(tqdm(train_generator)), phase * args.bert_steps, (phase + 1) * args.bert_steps):
        inp, lbl, meta = batch
        X_ids, X_type, X_attn = (X.cuda(rank) for X in inp)
        lbl = lbl.cuda(rank, non_blocking=True)
        start = time.perf_counter()
        with autocast(enabled=args.use_mp):
            pred = model(X_ids, X_type, X_attn)
            time_dict['forward_time'] += time.perf_counter() - start
            components, loss = criterion(pred, lbl, args)
            loss = loss / args.grad_acum_steps
            start = time.perf_counter()

        if args.use_mp:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        time_dict['backward_time'] += time.perf_counter() - start
        loss_components += components
        nb_tr_steps += 1
        if nb_tr_steps % args.grad_acum_steps == 0:
            if args.use_mp:
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
                optimizer.step()
            optimizer.zero_grad()
    return loss_components.sum() / nb_tr_steps

class ConditionalLoss:
    def __init__(self, args, rank):
        self.bce = nn.BCEWithLogitsLoss().cuda(rank)
        self.ce = nn.CrossEntropyLoss(reduction='none').cuda(rank)
        self.loss_weights = args.loss_weights

    def _binary_loss(self, pred, lbl):
        return self.bce(pred, lbl.unsqueeze(1))

    def _start_end_loss(self, pred, lbl, is_yes_no):
        is_span = (1 - is_yes_no)
        return (is_span * (self.ce(pred[0], lbl[0]) + self.ce(pred[1], lbl[1]))).mean()

    def __call__(self, preds, labels, args):
        pred_is_yes_no, pred_span = preds
        is_yes_no, yes_no, span = unpack_lbls(labels)
        s_w, yn_w, iyn_w = self.loss_weights

        is_y_n_loss = iyn_w * self._binary_loss(pred_is_yes_no, is_yes_no)
        span_loss = s_w * self._start_end_loss(pred_span, span, is_yes_no)

        return torch.Tensor([span_loss.item(), is_y_n_loss.item()]), is_y_n_loss + span_loss

@mcarilli, would you mind giving me some directions for debugging?

Based on your code snippet you are clipping the gradients in the FP32 training, but not in the amp run.
Could you add it as described here to get matching behavior and check if the loss still gets NaN values?

Thank you for the advice. I've added the gradient clipping as you suggested, but the loss is still NaN. The value in args.clip_grad is really large though, so I don't think it is doing anything; either way, it is just a simple way to catch huge gradients. But I agree it should be in both branches for consistency.

        if nb_tr_steps % args.grad_acum_steps == 0:
            if args.use_mp:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
                optimizer.step()
            optimizer.zero_grad()

Is there anything else I could try?

You could use forward hooks to check where the first invalid output is created and narrow down the issue.
Here is an example of using forward hooks to get the intermediate activations.
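
For reference, a minimal sketch of that approach could look like the following; the hook factory is illustrative and assumes the model and batch tensors (X_ids, X_type, X_attn) from the snippet above:

import torch

def make_nan_hook(name):
    # forward hook that reports modules producing non-finite activations
    def hook(module, inputs, output):
        outputs = output if isinstance(output, (tuple, list)) else (output,)
        for out in outputs:
            if isinstance(out, torch.Tensor) and not torch.isfinite(out).all():
                print(f"non-finite output in {name} ({module.__class__.__name__})")
    return hook

# register a hook on every submodule, run one forward pass under autocast, then clean up
handles = [m.register_forward_hook(make_nan_hook(n)) for n, m in model.named_modules()]
with torch.cuda.amp.autocast():
    _ = model(X_ids, X_type, X_attn)
for h in handles:
    h.remove()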

I would start by checking the model output first. If it’s valid, this would point to the custom loss function, which might create e.g. overflows.

Thank you for your advice, I was able to locate the overflow.
NaNs appear in layer 22 of the 24 layers of my BERT model.
Since I do not control where exactly mixed precision is applied, what can I do to overcome the issue?

Thanks for the test! Could you link to the BERT model you are using or a repository where I could try to reproduce this issue?

I am using a pretrained SpanBERT and loading it with transformers:

from transformers import BertConfig, BertModel


config = BertConfig.from_pretrained(args.model_path)
self.bert_qa_encoder = BertModel.from_pretrained(args.model_path, config=config)

AMP does not produce NaNs with the base model, only with the large one.
Is it possible that the pretrained weights already overflow fp16? If so, would clipping the weights before training be a solution?
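
As a quick sanity check (just a sketch, assuming bert_qa_encoder is the loaded BertModel), you could compare the pretrained weight magnitudes against the fp16 range; if nothing comes close to torch.finfo(torch.float16).max, the overflow is more likely happening in the activations than in the weights themselves:

import torch

fp16_max = torch.finfo(torch.float16).max  # 65504.0
for name, param in bert_qa_encoder.named_parameters():
    max_abs = param.detach().abs().max().item()
    if max_abs > fp16_max:
        print(f"{name}: max abs value {max_abs:.1f} exceeds the fp16 range")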

This could be the case, as other users report similar issues for this model here.
I'll check with our research team and see if they've already used this particular model.


Hi, sorry to bother you.
Any news from the research team?

I haven't gotten a response about SpanBERT yet.
However, based on this comment:

Are you seeing NaN outputs in the forward pass only (without training) after loading the pretrained weights for the "large" model?
If so, are you seeing the same invalid outputs if you create a new, randomly initialized model?
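
A rough sketch of that check, assuming the model path and input tensors from the earlier posts (BertModel(config) builds the same architecture with random weights):

import torch
from transformers import BertConfig, BertModel

config = BertConfig.from_pretrained(args.model_path)
pretrained = BertModel.from_pretrained(args.model_path, config=config).cuda().eval()
random_init = BertModel(config).cuda().eval()  # same architecture, random weights

with torch.no_grad(), torch.cuda.amp.autocast():
    for name, m in [("pretrained", pretrained), ("random init", random_init)]:
        out = m(input_ids=X_ids, token_type_ids=X_type, attention_mask=X_attn)[0]
        print(name, "all finite:", torch.isfinite(out).all().item())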

Hi,
I'm encountering a similar issue.
I'm getting NaN after backpropagating the first iteration. The model uses multiple losses, but I narrowed the problem down to kl_divergence. When I backpropagate torch.mean(self.kl_divergence(dist0, dist1)), I get the following message from anomaly detection:

File "/home/schatter/Soumick/Code/DS6/pipeline.py", line 262, in train
  self.scaler.scale(floss[i]).backward(retain_graph=True)
File "/home/schatter/anaconda3/envs/torchMRI17/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
  torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/schatter/anaconda3/envs/torchMRI17/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
  Variable._execution_engine.run_backward(
RuntimeError: Function 'CudnnConvolutionBackward' returned nan values in its 0th output.

Without AMP, it's working fine; the error only shows up when I'm using AMP.
Any ideas?
Thanks.

A bit more info from my debugging:
If I run the forward pass of only the module where the distributions are generated with autocast disabled, it works for 14 iterations and then throws the exact same error.
Here is the line where the distribution is generated:

dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)
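
One thing that might be worth trying (a sketch only; the clamp bounds are an arbitrary assumption) is to build the distribution in fp32 and bound log_sigma before exponentiating, since torch.exp overflows the fp16 range once log_sigma exceeds roughly 11:

import torch
from torch.distributions import Independent, Normal

with torch.cuda.amp.autocast(enabled=False):
    # cast the autocast (fp16) activations back to fp32 and keep exp() in a safe range
    mu32 = mu.float()
    log_sigma32 = log_sigma.float().clamp(min=-10.0, max=10.0)
    dist = Independent(Normal(loc=mu32, scale=torch.exp(log_sigma32)), 1)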

This points towards an instability in your overall training and doesn't seem to be related to AMP.
I'm not sure how "it only shows up when I'm using AMP" fits into the new post.

For further debugging I would suggest disabling amp for now and checking the inputs to the criterion that create the invalid outputs/gradients.

I already trained without issues by disabling AMP for the whole run, but with AMP I got that error. When I use AMP everywhere, the error appears in the very first iteration; when I use AMP everywhere except for the part I mentioned, I get the error after 14 iterations; and when I don't use AMP at all, everything works fine. So I was wondering whether the normal distribution or the KL divergence has any connection with AMP that could be behind this problem.

In my case, I think the NaN is caused by the loss being too large to be held by float16.
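
For reference, fp16 overflows just above 65504, so a simple (illustrative) check is whether the unscaled loss ever approaches that limit:

import torch

fp16_max = torch.finfo(torch.float16).max  # 65504.0
if not torch.isfinite(loss) or loss.abs().item() > 0.1 * fp16_max:
    print(f"loss {loss.item():.1f} is close to or beyond the fp16 limit {fp16_max}")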

Same issue, but I think I've resolved it. When we use a loss function like focal loss or cross entropy, which contains log(), some dimensions of the input tensor may be very small numbers: bigger than zero when the dtype is float32, but AMP changes the dtype to float16, and if we check these dimensions we will find they are [0.]. So as the input of log() they produce -inf, and the backward pass then turns this into NaN. There are two ways to solve the problem:

  1. Add a small number inside the log, like 1e-3. The price is a loss of precision.
  2. Make the dtype of the input of log() float32, e.g.:

    yhat = torch.sigmoid(input).type(torch.float32)
    loss = -y*((1-yhat) ** self.gamma) * torch.log(yhat + 1e-20) - (1-y) * (yhat ** self.gamma) * torch.log(1-yhat + 1e-20)
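
A related alternative (not from the post above, just a sketch; gamma and the tensor shapes are assumptions) is to avoid log(sigmoid(x)) altogether and use F.logsigmoid, which is numerically stable in both precisions:

import torch
import torch.nn.functional as F

def focal_loss_with_logits(logits, targets, gamma=2.0):
    # log(p) and log(1 - p) computed stably from the raw logits
    logits = logits.float()            # run the pointwise math in fp32
    p = torch.sigmoid(logits)
    log_p = F.logsigmoid(logits)       # log(sigmoid(x))
    log_1mp = F.logsigmoid(-logits)    # log(1 - sigmoid(x))
    loss = -targets * (1 - p) ** gamma * log_p - (1 - targets) * p ** gamma * log_1mp
    return loss.mean()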

I'm listing here a few things that I found mentioned in connection with this issue. For context, I was also training an LSTM model with AMP + DDP. These, most of which are brought up in this issue, helped to stabilize my model:

The instability, however, persisted, and the problem was finally solved by changing the model architecture. More specifically, there was an overflow in the running variance of one of the BN layers: the fix was to clip the max value of the input tensors before forwarding them to the BN layer, e.g.

...
x = self.relu(x)
x = torch.clamp(x, max=10.)
x = self.bn(x)
...

Since the clamping was done right after the ReLU (later Mish) activation, it essentially resulted in a clipped ReLU.

It turned out that the overflow tendency was there with AMP disabled as well, but in FP32 it never actually caused NaNs/Infs to appear.


I am using the CIFAR-10 dataset and the VGG-19 architecture. It works fine with the default FP32 training, but when I use mixed precision, the loss becomes NaN.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import models
from tqdm import tqdm

t1 = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])
train_dataset = torchvision.datasets.CIFAR10(root='datsets/',train=True,transform=t1,download=True)
train_loader = DataLoader(dataset = train_dataset,shuffle=True,batch_size=16)
model = models.vgg19(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr = 4e-3)
num_epochs = 10
device = 'cuda'
model = model.to(device)
scaler = torch.cuda.amp.GradScaler()
for epoch in range(num_epochs):
    loop = tqdm(enumerate(train_loader), total=len(train_loader))
    loss = 0
    for _, (input, target) in loop:  # iterate over the tqdm wrapper so set_postfix updates the bar
        input = input.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            output = model(input)
            loss =  criterion(output, target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        loop.set_description(f'Epoch [{epoch+1}]')
        loop.set_postfix(loss = loss.item())

Is there any fix?