CTC loss with variable input_lengths produces NaN values

Hello all,

I am working with PyTorch 1.0.1 and I am using PyTorch’s CTC loss.

The code looks like the following (I am working on GPU):

criterion = torch.nn.CTCLoss()
outs, (h, c) = lstm(input)  # input is padded with zeros
outs = torch.nn.functional.log_softmax(outs, dim=2)
loss = criterion(outs.permute(1, 0, 2).contiguous(), y.to(device), lengths_input.to(device), lengths_target.to(device))

I am permuting outs (a FloatTensor) because I am working batch_first, i.e. (N, T, C), while CTC requires (T, N, C).
y contains the concatenation of the labels of all batch elements and is a LongTensor.
lengths_input and lengths_target (LongTensors) are one-dimensional with N elements each, representing the length of each input sequence in the batch and of each target sequence in the labels, respectively.
I am working with variable-length sequences, so the values in lengths_input may differ (and they often do) from one another.
This setting causes the loss to become NaN after a few iterations.

However, if I fill lengths_input with all equal values (i.e. outs.size(1)), the loss decreases fine.
This is the only change I made from one version to the other.
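
For reference, the two settings differ only in how lengths_input is built; here is a minimal sketch (the names N, T and the concrete numbers are made up for illustration):

import torch

N, T = 4, 50  # batch size and padded length (T corresponds to outs.size(1) above)
# Variable lengths: the true (pre-padding) length of each sequence in the batch.
lengths_input = torch.tensor([50, 43, 37, 21], dtype=torch.long)
# Constant lengths: every sequence declared as long as the padded tensor.
lengths_input_const = torch.full((N,), T, dtype=torch.long)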

I was previously using PyTorch 1.0.0; then I read about this fixed bug and updated to 1.0.1, but nothing changed.

As I wrote in the comment, I am using padded inputs to obtain a single batch tensor. Is this not compatible with variable-sized lengths in the CTC loss? Am I missing something?

Thank you!

Oops, I found this in the CTC docs:

In order to use CuDNN, the following must be satisfied: targets must be in concatenated format; all input_lengths must be T; blank=0; target_lengths ≤ 256; the integer arguments must be of dtype torch.int32.

The regular implementation uses the (more common in PyTorch) torch.long dtype.

So I guess I cannot use different values for input_lengths on CUDA; I will try on the CPU. Anyway, even with fixed lengths I was using LongTensor instead of torch.int32 and no errors were raised. I don’t know whether the result is still correct in this case.
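
For completeness, satisfying the dtype requirement would look roughly like this (a sketch based on the snippet above; the other conditions quoted from the docs, e.g. all input_lengths equal to T, would also have to hold for CuDNN to be picked):

# The integer arguments cast to torch.int32, as the CuDNN path requires;
# with torch.long the "native" PyTorch implementation is used instead.
lengths_input_i32 = lengths_input.to(device=device, dtype=torch.int32)
lengths_target_i32 = lengths_target.to(device=device, dtype=torch.int32)
loss = criterion(outs.permute(1, 0, 2).contiguous(), y.to(device),
                 lengths_input_i32, lengths_target_i32)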

No, that is only for CuDNN; you’ll be OK if you use the “native” (PyTorch) version.
Two caveats:

  • If you feed “invalid” samples (targets too long for the input length), you probably want the zero_infinity option from the PyTorch nightlies (see the sketch after this list). This is very likely the source of your difficulties.
  • The GPU CTC currently doesn’t like zero-length targets (sorry, it’s on my to-do list).
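
As a sketch (assuming a build, e.g. a current nightly, where nn.CTCLoss already exposes the flag):

import torch

# With zero_infinity=True, samples whose loss is infinite (target too long for
# the input) contribute zero loss and zero gradient instead of producing NaN.
criterion = torch.nn.CTCLoss(blank=0, zero_infinity=True)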

Best regards

Thomas

Thank you for your reply!

I don’t have any zero-length targets, so this should not be a problem.

I did not understand exactly what you mean by an “invalid” sample: in my example above, does it happen if lengths_target[i] > lengths_input[i] for some i?

I will check the zero_infinity option, thank you!

The precise condition is target length + repeats > input length. Such a sample has likelihood 0, so the negative log likelihood (which is what the CTC loss is) is infinite. Zeroing that will help here, but it will distort your score if used in validation.
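
In code, the check for a single sample is roughly the following (a sketch; “repeats” counts adjacent identical labels in the target):

def is_invalid(target, input_length):
    # target: a Python sequence of label ids for one sample.
    # CTC has to insert a blank between repeated labels, so a target of length S
    # with R adjacent repeats needs at least S + R input frames.
    repeats = sum(1 for a, b in zip(target, target[1:]) if a == b)
    return len(target) + repeats > input_length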

Best regards

Thomas

I see, this surely happens in my case.

Thank you again.

Just to keep you updated on this: I used the nightly version with the zero_infinity flag set to True, but I am still getting NaN from a certain time step on.

What seems strange to me is that if I use a constant lengths_input (torch.full), this never happens.
Of course, in this way each input length stays the same or is increased, so I reduce the number of invalid samples; but by zeroing the infinite losses I should get the same result.

I will keep trying to fix it and will let you know.

If you can isolate the inputs that cause CTC or its gradients to go wrong (before the inputs become NaN), I’d be most interested. So far no one has done that, and it almost seems that for most people seeing NaNs the source is somewhere else.

Best regards

Thomas

OK, the problem is gone and the loss is decreasing fine. Maybe on the very first run after your suggestion I mistakenly executed the old script, I don’t know.
Anyway, on the subsequent runs I never faced the problem again. Thank you!

I’ll just share my observations here. I am also seeing this NaN gradient issue in my training code. I tried to capture the inputs to CTC when the gradients are NaN, using backward hooks. The code looks like this:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class CTCWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.ctc = nn.CTCLoss(blank=0, reduction='none', zero_infinity=True)

    def forward(self, logits, logit_lengths, labels, label_lengths):
        def _backward_hook(grad):
            if not np.isfinite(grad.sum().item()):
                # save logits, logit_lengths, labels, label_lengths and grad
                torch.save(...)

        log_probs = F.log_softmax(logits, 2)
        log_probs.register_hook(_backward_hook)
        return self.ctc(log_probs, labels, logit_lengths, label_lengths)
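
It is then called roughly like this (a hypothetical usage; logits is the (T, N, C) network output, and reduction='none' returns one loss per sample, so I reduce before backward):

ctc_loss = CTCWrapper().cuda()
per_sample_loss = ctc_loss(logits, logit_lengths, labels, label_lengths)
per_sample_loss.mean().backward()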

The variables are dumped successfully when a NaN gradient is detected. The dumped grad variable (I’ll call it dump_grad) has NaN values, while the others (logits, labels) don’t. However, when I tried those inputs on CTCWrapper alone, the backward gradients (I’ll call them offline_grad) have no NaN values.

I looked further into dump_grad and offline_grad: they differ only where dump_grad is NaN. dump_grad is a [396, 93, 96] tensor ((T, N, C) layout). It has four NaNs in four different samples, all of them at the last time step and in the blank label, i.e. only dump_grad[-1,:,0] has NaNs.
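
The comparison itself was essentially the following (a sketch of the checks rather than my exact code):

# Locate the NaNs and confirm the two gradients agree everywhere else.
nan_mask = torch.isnan(dump_grad)
print(nan_mask.nonzero())          # (t, n, c) indices of the NaN entries
finite = nan_mask == 0             # complement of the NaN mask
print(torch.equal(dump_grad[finite], offline_grad[finite]))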

I don’t know where to go from here, so I have to stop and use a workaround, like setting the NaNs to zeros.
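
Concretely, the workaround is just another hook on log_probs (a sketch; it replaces NaN entries in the gradient with zeros before they propagate into the network):

def _zero_nan_hook(grad):
    # A few NaN entries would otherwise poison the whole parameter update.
    grad = grad.clone()
    grad[torch.isnan(grad)] = 0
    return grad

log_probs.register_hook(_zero_nan_hook)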

I’m using PyTorch 1.1 stable and CUDA 10.0.

The values that contain NaNs are attached. I can upload the dump if anyone is interested.

>> dump_grad[-1,:,0]

tensor([-1.1732e-05,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,         nan,  1.8528e-04,  1.8528e-04,
         1.8528e-04,         nan,         nan,         nan,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04], device='cuda:0')
>> offline_grad[-1,:,0]

tensor([-1.1732e-05,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,  1.8528e-04,
         1.8528e-04,  1.8528e-04,  1.8528e-04], device='cuda:0')

There was a bug where uninitialized memory could end up in the unused parts of the input gradient. It is still present in PyTorch releases up to and including 1.1. Thanks to @yqwangustc’s fix, the nightlies / master should work better.

Best regards

Thomas
