Best practices to solve NaN CTC loss

I’ve encountered the CTC loss going NaN several times, and I believe many people run into this problem from time to time. So perhaps a collective list of best practices would be helpful.

Best Practices for Avoiding NaN CTC

1. Mismatch between model’s number of classes and class_ids in labels

A common problem is that, seeing that the largest class id in our label_list is C, we mistakenly set the model’s number of classes to C as well. Instead, it should be C + 1, because class 0 is usually reserved for the BLANK (or PAD) token.

Suppose your labels contain classes 1 to 100, and the model also has number_of_classes = 100. Whenever the CTC loss encounters a label tensor containing 100 together with model predictions of shape (T, B, 100), it silently goes bonkers, because it actually needs (T, B, 101).

####################
# Check 1
####################

max_class_id = 0
for label_sequence in list_of_all_labels:
    for class_id in label_sequence:
        max_class_id = max(max_class_id, class_id)
assert max_class_id == number_of_model_classes - 1
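One way to avoid the off-by-one in the first place is to derive the class count from the vocabulary itself. A minimal sketch, assuming character-level labels; `build_vocab` and `BLANK_ID` are hypothetical names, not part of any library:

```python
from typing import Dict, Tuple

BLANK_ID = 0  # class 0 is reserved for the CTC blank (also used as padding)


def build_vocab(charset: str) -> Tuple[Dict[str, int], int]:
    """Map each distinct character to a class id starting at 1 (0 is the blank)."""
    char_to_id = {ch: i + 1 for i, ch in enumerate(sorted(set(charset)))}
    # The largest class id is len(char_to_id), so the model needs one more output.
    num_model_classes = len(char_to_id) + 1
    return char_to_id, num_model_classes


char_to_id, num_model_classes = build_vocab("hello world")
print(num_model_classes)  # 9: 8 distinct characters + 1 blank
```

With this setup, Check 1 holds by construction as long as every vocabulary entry actually occurs in the labels.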

2. Check your label classes

####################
# Check 2
####################

for label_sequence in list_of_all_labels:
    for class_id in label_sequence:
        assert 0 < class_id < number_of_model_classes

You have to consider what to do if you have labels out of your class_list. Either drop those samples, or have a specific class UNKNOWN (I usually have it as class 1) where you assign unknown classes.
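Both options can be sketched as follows, assuming the convention just described (blank = 0, UNKNOWN = 1, known characters from 2 upward). The helper names are hypothetical:

```python
from typing import Dict, List, Optional, Tuple

BLANK_ID = 0  # CTC blank / padding
UNK_ID = 1    # UNKNOWN class, following the convention described above


def build_vocab_with_unk(charset: str) -> Tuple[Dict[str, int], int]:
    """Known characters get ids 2, 3, ...; ids 0 and 1 stay reserved."""
    char_to_id = {ch: i + 2 for i, ch in enumerate(sorted(set(charset)))}
    return char_to_id, len(char_to_id) + 2  # blank + UNK + known classes


def encode_with_unk(text: str, char_to_id: Dict[str, int]) -> List[int]:
    """Option 1: map out-of-vocabulary characters to UNK_ID."""
    return [char_to_id.get(ch, UNK_ID) for ch in text]


def encode_or_drop(text: str, char_to_id: Dict[str, int]) -> Optional[List[int]]:
    """Option 2: return None so the caller can drop the whole sample."""
    if any(ch not in char_to_id for ch in text):
        return None
    return [char_to_id[ch] for ch in text]


vocab, num_model_classes = build_vocab_with_unk("ab")
print(encode_with_unk("abz", vocab), encode_or_drop("abz", vocab))  # [2, 3, 1] None
```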

Check 2 can easily be done while doing check 1; I’m only writing them separately for emphasis.
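A single-pass version of Checks 1 and 2 could look like this sketch. `validate_labels` is a hypothetical helper, and labels are assumed to be plain lists of ints:

```python
from typing import List, Sequence, Tuple


def validate_labels(
    labels: Sequence[Sequence[int]], num_model_classes: int
) -> Tuple[int, List[int]]:
    """Checks 1 and 2 in one pass.

    Returns the largest class id seen, and the indices of label sequences
    containing an id outside 1 .. num_model_classes - 1.
    """
    max_class_id = 0
    bad_indices = []
    for idx, label_sequence in enumerate(labels):
        for class_id in label_sequence:
            max_class_id = max(max_class_id, class_id)
        # 0 is the blank and must never appear in a target.
        if any(not (0 < class_id < num_model_classes) for class_id in label_sequence):
            bad_indices.append(idx)
    return max_class_id, bad_indices


max_id, bad = validate_labels([[1, 2], [3, 0], [5]], num_model_classes=5)
print(max_id, bad)  # 5 [1, 2]
```

Returning the offending indices, rather than asserting, makes it easy to drop or inspect the bad samples.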

3. Check your label lengths

####################
# Check 3
####################

min_expected_length = 1

for label_sequence in list_of_all_labels:
    assert len(label_sequence) < model_prediction_sequence_length
    assert len(label_sequence) >= min_expected_length

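The target length should not exceed the model’s prediction sequence length. More precisely, CTC needs one frame per label plus an extra blank frame between every pair of consecutive identical labels, so a label like "aab" needs at least 4 frames. Empty labels and labels that are too short also cause problems.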

Drop all samples with empty or too-short labels. For long labels, either drop anomalous samples whose labels exceed some length threshold, or increase the model’s sequence length so that it is guaranteed to be longer than your labels.

I’ve found a model sequence length of twice max_label_length to work well empirically; I’d like to hear other opinions on this.
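A repeat-aware version of Check 3: CTC cannot emit the same label in two consecutive frames without a blank between them, so a label needs at least len(label) plus the number of adjacent repeats frames. A sketch of a filter built on that, assuming labels are lists of ints; the helper names are hypothetical:

```python
from typing import List, Sequence


def min_ctc_frames(label_sequence: Sequence[int]) -> int:
    """Fewest model time steps that can emit this label under CTC.

    Each label needs one frame, and every pair of consecutive identical
    labels needs an extra blank frame between them.
    """
    repeats = sum(
        1 for prev, cur in zip(label_sequence, label_sequence[1:]) if prev == cur
    )
    return len(label_sequence) + repeats


def keep_indices(
    labels: Sequence[Sequence[int]], model_seq_len: int, min_len: int = 1
) -> List[int]:
    """Indices of samples whose labels are long enough and still alignable."""
    return [
        i
        for i, lab in enumerate(labels)
        if len(lab) >= min_len and min_ctc_frames(lab) <= model_seq_len
    ]


print(min_ctc_frames([1, 1, 2]))  # 4
print(keep_indices([[1, 1, 2], [1, 2], []], model_seq_len=3))  # [1]
```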

4. Catch CTC problems right at the start

In your training script, insert the following line somewhere at the start:

torch.autograd.set_detect_anomaly(True)

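This will immediately halt training and point to the first place where something went wrong. Anomaly detection slows training down considerably, though, so enable it only while debugging.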
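Because anomaly mode is expensive, a lighter pattern is to check every loss with torch.isfinite and fail fast, and only re-run a suspicious step inside the torch.autograd.detect_anomaly() context manager. A sketch; `assert_finite` and `debug_step` are hypothetical helpers:

```python
from typing import Callable

import torch


def assert_finite(loss: torch.Tensor, step: int) -> None:
    """Cheap check to run every step: stop as soon as the loss is inf or NaN."""
    if not torch.isfinite(loss).all():
        raise RuntimeError(f"non-finite loss {loss.detach().cpu().tolist()} at step {step}")


def debug_step(compute_loss: Callable[[], torch.Tensor]) -> torch.Tensor:
    """Re-run one suspicious step (forward + backward) under anomaly detection.

    In anomaly mode, autograd raises if a backward function returns NaN and
    reports the forward operation that created it. Slow: debugging only.
    """
    with torch.autograd.detect_anomaly():
        loss = compute_loss()
        loss.backward()
    return loss
```

Usage: call assert_finite(loss, step) right after computing the loss; when it fires, wrap the same batch in debug_step(lambda: criterion(model(data), targets)).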

5. Set the zero_infinity parameter

In nn.CTCLoss and nn.functional.ctc_loss, there is a parameter zero_infinity which is False by default. Use nn.CTCLoss(..., zero_infinity=True).

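Sometimes the loss becomes inf before it becomes NaN; with zero_infinity=True, such infinite losses (and their gradients) are reset to zero. Infinite losses mainly occur when an input sequence is too short to be aligned to its target.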
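A small demonstration, with made-up shapes, of an input that is too short for its target (2 frames, 3 distinct labels): the default loss is infinite, while zero_infinity=True turns it into 0:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, N, C = 2, 1, 5  # only 2 frames, batch of 1, 5 classes (0 = blank)
log_probs = torch.randn(T, N, C).log_softmax(-1)
targets = torch.tensor([[1, 2, 3]])  # 3 labels cannot fit into 2 frames
input_lengths = torch.tensor([T])
target_lengths = torch.tensor([3])

loss_default = F.ctc_loss(
    log_probs, targets, input_lengths, target_lengths, reduction="sum"
)
loss_zeroed = F.ctc_loss(
    log_probs, targets, input_lengths, target_lengths, reduction="sum", zero_infinity=True
)
print(loss_default.item(), loss_zeroed.item())
```

Note that zero_infinity only hides the symptom for that sample; the length checks in section 3 address the cause.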

6. Gradient Clipping

In your training step, clip the gradient norm to some value. I read somewhere that a good value lies in the closed range [0.5, 5.0], but haven’t verified this; I usually clip to 1.0 and it works fine.

optimizer.zero_grad()
preds = model(data)
loss = criterion(preds, targets)
loss.backward()

# clip here
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # or some other value

optimizer.step()
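clip_grad_norm_ also returns the total gradient norm measured before clipping, which is worth logging while hunting NaNs. A sketch of the same step that uses it to skip updates with non-finite gradients. Skipping the step is my own assumption, not part of the recipe above, and `train_step` is a hypothetical helper:

```python
from typing import Tuple

import torch
from torch import nn


def train_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    data: torch.Tensor,
    targets: torch.Tensor,
    max_norm: float = 1.0,
) -> Tuple[float, float, bool]:
    """One step with gradient clipping; returns (loss, pre-clip grad norm, stepped)."""
    optimizer.zero_grad()
    preds = model(data)
    loss = criterion(preds, targets)
    loss.backward()

    # Total norm of all gradients *before* clipping; log it to spot spikes.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    if not torch.isfinite(total_norm):
        # Assumption: discard this update rather than apply inf/NaN gradients.
        optimizer.zero_grad()
        return loss.item(), total_norm.item(), False

    optimizer.step()
    return loss.item(), total_norm.item(), True


torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss, grad_norm, stepped = train_step(
    model, optimizer, nn.MSELoss(), torch.randn(8, 4), torch.randn(8, 2)
)
print(loss, grad_norm, stepped)
```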

7. Low Learning Rate

A seemingly normal learning rate (like 0.01) may still be too large for certain problems.

Generally, AdamW with learning rate = 0.0003 and weight_decay = 0.01, with warmup scheduling for the first 1000 steps followed by cosine annealing for the remaining training period, is a good initial combination when tackling a new problem or project. After the initial 3e-4 learning rate, I explore by roughly a factor of 3 in either direction, i.e. 0.001 and 0.0001. The trio (0.001, 0.0003, 0.0001) usually gives me a sense of which learning rates work. For problems where more regularization seems necessary, I also search for weight decay among (0.01, 0.03, 0.1).
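A sketch of that AdamW + warmup + cosine setup using torch.optim.lr_scheduler.LambdaLR. The linear warmup shape and the decay all the way to zero are assumptions, and `make_optimizer_and_scheduler` is a hypothetical helper:

```python
import math

import torch
from torch import nn


def make_optimizer_and_scheduler(
    model: nn.Module,
    total_steps: int,
    lr: float = 3e-4,
    weight_decay: float = 0.01,
    warmup_steps: int = 1000,
):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    def lr_lambda(step: int) -> float:
        # Multiplier applied to the base lr at a given scheduler step.
        if step < warmup_steps:
            return (step + 1) / warmup_steps  # linear warmup up to 1.0
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))  # cosine to 0

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler


model = nn.Linear(4, 2)
optimizer, scheduler = make_optimizer_and_scheduler(model, total_steps=10_000)
# In the training loop: loss.backward(); optimizer.step(); scheduler.step()
print(optimizer.param_groups[0]["lr"])
```

For the learning-rate sweep, only the lr argument changes between runs (1e-3, 3e-4, 1e-4).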

Other Useful CTC Tips

1. Handling variable length data and labels

Problem: Each __getitem__ call in your dataset produces labels (and perhaps data) of variable length. To batch them with your dataloader and process them with your model, you most likely need them to have the same sequence length.

Solution: The obvious way to do this is to pad your data and targets to either a constant maximum length, or to the length of the longest sequence in the batch. You can do this inside the collate function (collate_fn parameter in dataloader; a good discussion here).

Example for handwritten text recognition:

from typing import Tuple

import torch
from torch import nn


def collate_func(batch: list) -> Tuple[torch.Tensor, torch.Tensor]:
    """Custom collate function, that zero-pads labels to the size of longest label in the batch.

    Args:
        batch (list): List of (image, label) pairs.

    Returns:
        torch.Tensor: Image tensor of shape (b, c, h, w).
        torch.Tensor: Label tensor of shape (b, max_seq_len_in_batch)
    """
    images, labels = map(list, zip(*batch))
    images = torch.stack(images)

    labels = nn.utils.rnn.pad_sequence(
        sequences=labels,
        batch_first=True,
        padding_value=0,
    )

    return images, labels

I only have to handle labels here. In other tasks like video recognition, it may be necessary to pad image sequences too (probably a better idea to do that inside __getitem__, I think).
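If you would rather not infer label lengths from the padding later, a variant of the collate function could return them explicitly. A sketch, assuming each label is a 1-D LongTensor; `collate_with_lengths` is a hypothetical name:

```python
from typing import List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence


def collate_with_lengths(
    batch: List[Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Like collate_func, but also returns the unpadded label lengths."""
    images, labels = zip(*batch)
    images = torch.stack(images)
    # Record true lengths *before* padding, so they don't depend on the pad value.
    label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
    labels = pad_sequence(list(labels), batch_first=True, padding_value=0)
    return images, labels, label_lengths


batch = [
    (torch.zeros(1, 32, 128), torch.tensor([3, 1, 2])),
    (torch.zeros(1, 32, 128), torch.tensor([5])),
]
images, labels, label_lengths = collate_with_lengths(batch)
print(images.shape, labels.tolist(), label_lengths.tolist())
```

It plugs in the same way, via DataLoader(..., collate_fn=collate_with_lengths), and the lengths can be passed straight to F.ctc_loss as target_lengths.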

2. Handling lengths & logsoftmax in CTC loss

The CTC loss also requires input lengths and label lengths. It also expects its inputs (model predictions) to be log-softmaxed first. This is easy to overlook, since the last layer of a model is typically nn.Linear, and common losses like nn.CrossEntropyLoss apply log-softmax internally, so we don’t need to do it ourselves there. A full example of the loss I use, again for text recognition:

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange


class CTCLoss(nn.Module):
    """Convenient wrapper for CTCLoss that handles log_softmax and taking input/target lengths."""

    def __init__(self, blank: int = 0) -> None:
        """Init method.

        Args:
            blank (int, optional): Blank token. Defaults to 0.
        """
        super().__init__()
        self.blank = blank

    def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Forward method.

        Args:
            preds (torch.Tensor): Model predictions. Tensor of shape (batch, sequence_length, num_classes), or (N, T, C).
            targets (torch.Tensor): Target tensor of shape (batch, max_seq_length). max_seq_length may vary
                per batch.

        Returns:
            torch.Tensor: Loss scalar.
        """
        preds = preds.log_softmax(-1)
        batch, seq_len, classes = preds.shape
        preds = rearrange(preds, "n t c -> t n c") # since ctc_loss needs (T, N, C) inputs
        # equiv. to preds = preds.permute(1, 0, 2), if you don't use einops

        pred_lengths = torch.full(size=(batch,), fill_value=seq_len, dtype=torch.long)
        target_lengths = torch.count_nonzero(targets, dim=1)  # assumes 0 is used only for padding

        return F.ctc_loss(preds, targets, pred_lengths, target_lengths, blank=self.blank, zero_infinity=True)

Models usually output predictions of a fixed sequence length, so prediction lengths can be created with torch.full. Target lengths can be extracted from the padded target tensor with torch.count_nonzero, given that we used 0 (the blank) to pad them in the first place.
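For readers not using einops, a sketch of the same computation as a plain function using permute, plus a quick smoke test on random data. It assumes 0 is both the blank and the padding value; `ctc_loss_batch_first` is a hypothetical name:

```python
import torch
import torch.nn.functional as F


def ctc_loss_batch_first(
    preds: torch.Tensor, targets: torch.Tensor, blank: int = 0
) -> torch.Tensor:
    """preds: raw logits (N, T, C); targets: labels (N, S), zero-padded."""
    log_probs = preds.log_softmax(-1).permute(1, 0, 2)  # (T, N, C), as ctc_loss expects
    T, N, _ = log_probs.shape
    input_lengths = torch.full((N,), T, dtype=torch.long)
    # Valid only because 0 never appears as a real label.
    target_lengths = torch.count_nonzero(targets, dim=1)
    return F.ctc_loss(
        log_probs, targets, input_lengths, target_lengths, blank=blank, zero_infinity=True
    )


torch.manual_seed(0)
preds = torch.randn(2, 10, 6, requires_grad=True)  # N=2, T=10, C=6
targets = torch.tensor([[1, 2, 3], [4, 5, 0]])
loss = ctc_loss_batch_first(preds, targets)
loss.backward()
print(loss.item(), preds.grad.shape)
```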

Common Mysterious Behavior

Strange phenomena that occur when using CTC, and solutions (if any):

  • NaN loss occurs within a few steps of training, but if the batch_size is set to 1, NaN loss no longer occurs.

    • Because of the nature of this occurrence, at first glance it seems there may be some issue with padding the targets. However, I have personally encountered this recently, and the underlying problem eventually turned out to be case 1: the model had one fewer class than necessary. It would be great if we could figure out why NaN doesn’t show up with batch_size 1 in such cases.
  • NaN loss occurs during GPU training, but if CPU is used it doesn’t happen, strangely enough.

    • This most likely happened only in old versions of torch, due to some bug, but I’d like to know whether this phenomenon is still around.
  • Model only predicts blanks at the start, but later starts working normally

    • Is this behavior normal?
  • Model works normally for a while, but eventually keeps predicting blanks

    • I don’t know why this happens, but it usually goes away when other issues are resolved.

Closing Remarks

  • Please correct me if I wrote anything incorrect! Some of what I’ve written here comes from empirical evidence and hearsay, and is not backed by any theory. I would love to know more about why X happens.
  • If there are other best practices to try when facing NaN issues, or other solutions that work, it would be great if you could share them here, or simply link to an existing great post.
  • Please also share other strange behavior you have encountered with CTC, along with possible underlying reasons (theoretically proven or empirically discovered).

Some other useful threads: [a, b, c, d]
