CTCLoss doesn't work in PyTorch

Hi,

I am trying to use CTCLoss from https://github.com/SeanNaren/warp-ctc,
but it always returns loss = inf/nan for large batch sizes. Does anyone know how to fix this? Thanks in advance.

Can you post a minimal gist or similar to reproduce it?
(I.e. precompute the outputs and targets and keep just your CTC call.)
It works for me, but it behaves oddly on invalid inputs.

Best regards Thomas

Yes, here is a small sample that should recognize a sequence of MNIST digits:

import numpy as np
import torch
from torch import Tensor
from torch.autograd import Variable

from warpctc_pytorch import CTCLoss

criterion = CTCLoss()

batch_size = 256
for i in range(10):
    # Two random labels in 0..9 per sample, flattened into one 1-D tensor
    labels = Variable(torch.from_numpy(np.random.randint(0, 10, (batch_size, 2))).int()).view(-1)
    # Random activations shaped (seq_len=2, batch_size, 11 classes)
    acts = Variable(torch.randn((2, batch_size, 11)), requires_grad=True)
    # Every input sequence and every target has length 2
    act_lens = Variable(Tensor([2] * batch_size).int())
    label_lens = Variable(Tensor([2] * batch_size).int())
    loss = criterion(acts, labels, act_lens, label_lens) / batch_size
    loss.backward()
    print("loss: {}".format(loss.data[0]))

So could you please grab example values of acts, labels, act_lens and label_lens and put together a single file that contains just the call to criterion? That would be much quicker to look at. Or you could print the types and shapes of these and see if anything looks suspicious.
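As a rough sketch of that kind of check: the helper below (find_ctc_input_problems is a hypothetical name, not part of warp-ctc or PyTorch) takes plain Python lists plus the activation dimensions and reports anything that looks off. It assumes blank index 0 and the flattened 1-D label layout used in the sample above, and it only covers the obvious consistency checks.

```python
def find_ctc_input_problems(num_steps, batch_size, num_classes,
                            labels, act_lens, label_lens, blank=0):
    """Return a list of human-readable problems; an empty list means none found.

    Hypothetical helper for illustration only. `labels` is the flattened
    list of all targets, `act_lens` and `label_lens` hold one length per sample.
    """
    problems = []
    if len(act_lens) != batch_size:
        problems.append("act_lens has {} entries, expected {}".format(
            len(act_lens), batch_size))
    if len(label_lens) != batch_size:
        problems.append("label_lens has {} entries, expected {}".format(
            len(label_lens), batch_size))
    if sum(label_lens) != len(labels):
        problems.append("sum(label_lens) is {}, but there are {} labels".format(
            sum(label_lens), len(labels)))
    if any(n > num_steps for n in act_lens):
        problems.append("an entry of act_lens exceeds the number of time steps")
    if any(l > a for l, a in zip(label_lens, act_lens)):
        problems.append("a target is longer than its input sequence")
    if any(lab == blank for lab in labels):
        problems.append("labels contain the blank index {}".format(blank))
    if any(lab < 0 or lab >= num_classes for lab in labels):
        problems.append("labels contain a value outside 0..{}".format(
            num_classes - 1))
    return problems


# Two samples, two time steps, 11 classes; the first label collides with blank.
print(find_ctc_input_problems(2, 2, 11, [0, 3, 5, 9], [2, 2], [2, 2]))
```

With tensors you would pass something like labels.data.tolist() and acts.size() into it; that glue is left out here to keep the sketch library-free.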

Best regards

Thomas

OK, I made a minimal sample that reproduces the error.
If batch_size = 1 it works, but if batch_size = 32, for example, I get a loss value of inf.

With the random label code you posted above, one problem seems to be that 0 should not appear in the labels. If I change the labels to

    labels = Variable(torch.from_numpy(np.random.randint(1, 10, (batch_size, 2))).int()).view(-1)

(note the lower limit of 1 in randint), I get losses of about 5±0.3.
The 0 label is reserved for the blank RNN output in the warp_ctc implementation.
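For real digit labels you probably do not want to drop the digit 0, so one option is to shift every class id up by one and keep index 0 for the blank. Here is a minimal sketch of that mapping; both function names are made up for illustration, and it assumes digits 0..9 with 11 output classes as in the sample above.

```python
BLANK = 0  # index reserved for the blank output, as described above


def to_blank_zero_labels(digits):
    """Map class ids 0..9 to 1..10 so that index 0 stays free for the blank."""
    if any(d < 0 for d in digits):
        raise ValueError("class ids must be non-negative")
    return [d + 1 for d in digits]


def from_blank_zero_labels(labels):
    """Inverse mapping: drop blanks and shift 1..10 back to 0..9."""
    return [lab - 1 for lab in labels if lab != BLANK]


digits = [0, 9, 4]
shifted = to_blank_zero_labels(digits)
print(shifted)                          # [1, 10, 5]
print(from_blank_zero_labels(shifted))  # [0, 9, 4]
```

The shifted values are what would go into labels; predictions then need the inverse shift before you compare them with the original digits.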

Best regards

Thomas


Thanks!

It’s strange that label 0 is reserved for the blank instead of, for example, max_label + 1 (with max_label as an input parameter of the loss function).
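If a model already emits the blank as its last class (index max_label + 1), a possible workaround is to reorder each score row so the blank comes first, and shift the labels up by one to match. A small sketch on plain lists, with a hypothetical function name:

```python
def blank_last_to_blank_first(score_row):
    """Move the last entry (the blank score) to index 0, keeping the rest in order."""
    return [score_row[-1]] + list(score_row[:-1])


# Scores for classes 0, 1 and a trailing blank at one time step.
row = [0.1, 0.2, 0.7]
print(blank_last_to_blank_first(row))  # [0.7, 0.1, 0.2]
```

After this reordering, class k sits at index k + 1, which is why the labels need the same shift.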