I’ve been trying to use CTCLoss for one of my projects. For unknown reasons I couldn’t get it to work, and it would always give me 0 with certain settings. So I decided to play around with the example included in the documentation of nn.CTCLoss and see if I could break it.
Turns out, for some inputs this function returns inf. See the example below. I don’t know enough about the internals of the CTC algorithm to know why such inputs yield inf (tested with PyTorch 1.5.1):
```python
import torch
import torch.nn as nn

torch.manual_seed(4)

T = 375  # Input sequence length
C = 6    # Number of classes (including blank)
N = 1    # Batch size

# Initialize random batch of input vectors, for *size = (T,N,C)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)

# Initialize random batch of targets (0 = blank, 1:C = classes)
target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)

ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
print(loss)  # prints "inf"
```
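While experimenting, I could also construct a case that (I assume) is guaranteed to return inf: when the target is longer than the input, no valid CTC alignment exists, so the log-likelihood should be -inf. This is just my guess about what might be happening above; I also tried the `zero_infinity` flag of nn.CTCLoss, which replaces infinite losses with zero:

```python
import torch
import torch.nn as nn

T, C, N = 4, 6, 1  # input length deliberately shorter than the target
input = torch.randn(T, N, C).log_softmax(2)
input_lengths = torch.full((N,), T, dtype=torch.long)

# Target of length 5 > T = 4: no valid CTC alignment can exist
# (my assumption for why the loss comes out infinite).
target = torch.tensor([1, 2, 3, 4, 5], dtype=torch.long)
target_lengths = torch.tensor([5], dtype=torch.long)

ctc_loss = nn.CTCLoss()
print(ctc_loss(input, target, input_lengths, target_lengths))  # tensor(inf)

# zero_infinity=True zeroes out infinite losses (and their gradients):
ctc_loss_z = nn.CTCLoss(zero_infinity=True)
print(ctc_loss_z(input, target, input_lengths, target_lengths))  # tensor(0.)
```

That silences the inf, but it obviously doesn’t explain why the original example produces it in the first place.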
I wanted to open a bug report on GitHub, but I didn’t want to waste anybody’s time in case I was doing something stupid. Any ideas what the issue could be?