Proper way to use torch.nn.CTCLoss

I am using CTC in an LSTM-OCR setup and was previously using a CPU implementation. I am now looking to use the CTCLoss function in PyTorch, but I have some issues making it work properly. My test model is very simple and consists of a single Bi-LSTM layer followed by a single linear layer.

def make_model(ninput=48, noutput=97):
    return nn.Sequential(
        # run 1D LSTM
        layer.Lstm1(100),
        # reorder for Linear layer
        layer.Reorder("BDW", "BWD"),
        # run single linear layer
        layer.Linear(noutput),
        # reorder to CTC convention
        layer.Reorder("BWD", "WBD"))

My inputs are image batches in the form “BDW” (batch, depth, width). My targets are of the form “BL”, for example:

target[0] = tensor([18, 50, 43, 61, 39, 52, 42, 43, 56])

with the numbers going from 1 to C, reserving 0 for “blank”.
I then train the model like this:

def train(model, input, target, input_lengths, target_lengths):
    assert input.size(0) == target.size(0)
    logits = model.forward(input)
    probs = nn.functional.log_softmax(logits, dim=2)
    optimizer.zero_grad()
    loss = ctc(probs, target, input_lengths, target_lengths)
    loss.backward() 
    optimizer.step()
    return nn.functional.softmax(logits, dim=2)

For the optimizer I use SGD.
When I train on my data set, the model predicts only a single letter at first, and after a couple of epochs it predicts only blank for all inputs. If I train on a single sample and the letter predicted at first is part of the target, the probability of that letter rises to 1 for any input instead of the model predicting blank.
So far I am using a batch size of 1, because I have additional problems with how to provide the data for larger batches. If I provide the input as a “BDW” tensor, where “W” is the maximum input_length for all samples in the batch, zero-pad all other samples to the same length and provide the correct input_lengths, the model produces “NaN” after a few epochs.
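A minimal sketch of how a zero-padded batch can be handed to nn.CTCLoss, assuming log-probabilities in (T, B, C) order, targets padded to (B, L_max), and blank index 0. The sizes and random tensors are made up for illustration. Setting zero_infinity=True is one common guard against infinite losses (which occur when an input is too short for its target) turning into NaN gradients; it is not necessarily the cause of the NaNs described above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

T, B, C = 50, 3, 97          # max input width, batch size, classes incl. blank (illustrative)
L_max = 9                    # longest target in the batch

# Stand-in for the model output in "WBD" order: raw scores of shape (T, B, C).
logits = torch.randn(T, B, C, requires_grad=True)
log_probs = logits.log_softmax(dim=2)

# Targets padded to (B, L_max); labels run 1..C-1 because 0 is the blank.
targets = torch.randint(1, C, (B, L_max), dtype=torch.long)
input_lengths = torch.tensor([50, 42, 30], dtype=torch.long)   # true widths before padding
target_lengths = torch.tensor([9, 7, 4], dtype=torch.long)     # true label counts

ctc = nn.CTCLoss(blank=0, reduction='none', zero_infinity=True)
per_sample = ctc(log_probs, targets, input_lengths, target_lengths)  # shape (B,)

loss = per_sample.mean()     # average over the batch only
loss.backward()
print(per_sample.shape, torch.isfinite(loss).item())
```

Entries of targets beyond each target length, and time steps beyond each input length, are ignored by the loss, so the padding values themselves should not matter as long as the two length tensors are correct.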

I had reasonable outputs using the CTC implementation mentioned in the beginning, although it was a lot slower, so I assume I am using the PyTorch version incorrectly.

UPDATE: I at least figured out why there didn’t seem to be any training going on. I am not sure how PyTorch scales the CTC loss, but the updates were so much smaller than with the implementation I used previously that training stopped too early. After increasing the learning rate, I can see that training is happening.

Set reduction='none' on the loss; otherwise each loss is divided by its target length and averaged over the batch, resulting in really slow training.

With reduction='none' I get a list of losses, one per sequence in my batch.
How do you propose I use those, if not by averaging them as reduction='mean' does?

reduction='mean' will also average over the target lengths, so by using reduction='sum', or reduction='none' and taking the mean only over the batch dimension, you’ll get a larger gradient.
That said, for Adam, it should cancel with the implied gradient weighting, and for SGD you could use a higher learning rate, too.

Best regards

Thomas
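To make the difference between the reductions concrete, here is a small sketch using torch.nn.functional.ctc_loss on random data. The sizes, labels, and variable names are illustrative assumptions; only the relationship between the three reductions is the point.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

T, B, C = 40, 3, 20                      # illustrative sizes
log_probs = torch.randn(T, B, C).log_softmax(dim=2)
targets = torch.randint(1, C, (B, 8), dtype=torch.long)
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.tensor([8, 5, 2], dtype=torch.long)

args = (log_probs, targets, input_lengths, target_lengths)
none = F.ctc_loss(*args, blank=0, reduction='none')   # one loss per sequence
total = F.ctc_loss(*args, blank=0, reduction='sum')
mean = F.ctc_loss(*args, blank=0, reduction='mean')

# 'mean' divides each loss by its target length before averaging over the batch.
manual_mean = (none / target_lengths).mean()
batch_mean = none.mean()                               # average over the batch only

print(torch.allclose(mean, manual_mean), torch.allclose(total, none.sum()))
print(batch_mean.item() > mean.item())
```

Because the batch-only mean skips the division by target length, it is larger than the 'mean' reduction whenever targets are longer than one label, and the gradients scale accordingly.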

I also encountered a similar problem (i.e. the model only predicts blank). Additionally, I found that a nearly perfect prediction has a higher loss than predict_all_blank. (I also pre-pad the y labels with blank.) Here’s the setup for replication; I’m wondering whether I’ve used the loss properly or whether there might be a bug. Please let me know if additional info is needed. I really appreciate it!

Here are the observations. Note that the loss for the perfect prediction is higher than the loss for all_blank, with or without reduction='none'. The experiment setup is at the bottom:

# print loss
tloss = torch.nn.CTCLoss(blank=79, zero_infinity=False, reduction='none')
print('Perfect prediction:\n', tloss(pred_perf,  batch_y_cat, inputls, outputls))
print('Model prediction:\n',tloss(pred_model,  batch_y_cat, inputls, outputls))

# output
# Perfect prediction:
# tensor([110.0361, 109.6828, 107.2605], device='cuda:0')
# Model prediction:
# tensor([86.3294, 90.4917, 38.5629], device='cuda:0')

# print the predicted raw results
tloss = torch.nn.CTCLoss(blank=79, zero_infinity=False, reduction='mean')
for idx in range(3):
    print('=========================================')
    print('Prediction - perfect prediction')
    print(pred_perf.argmax(dim=2).permute((1,0))[idx])
    print("loss:", tloss(pred_perf[:,idx, :].unsqueeze(1),  batch_y[idx].unsqueeze(0), inputls[idx], outputls[idx]))
    print('--------')
    print('Prediction - model')
    print(pred_model.argmax(dim=2).permute(1,0)[idx])
    print("loss:", tloss(pred_model[:,idx, :].unsqueeze(1),  batch_y[idx].unsqueeze(0), inputls[idx], outputls[idx]))
    print('--------')
    print('Ground Truth')
    print(batch_y[idx])
    print('Unpadded ground truth')
    unpad_y = batch_y[idx][: outputls[idx]]
    print(unpad_y)
# output
# =========================================
# Prediction - perfect prediction
# tensor([55, 43, 40, 62, 41, 44, 53, 40, 62, 41, 50, 53, 62, 58, 36, 54, 43, 44,
#         49, 42, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79], device='cuda:0')
# loss: tensor(30.2797, device='cuda:0')
# --------
# Prediction - model
# tensor([79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79], device='cuda:0')
# loss: tensor(3.2668, device='cuda:0')
# --------
# Ground Truth
# tensor([55, 43, 40, 62, 41, 44, 53, 40, 62, 41, 50, 53, 62, 58, 36, 54, 43, 44,
#         49, 42, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79],
#        device='cuda:0')
# Unpadded ground truth
# tensor([55, 43, 40, 62, 41, 44, 53, 40, 62, 41, 50, 53, 62, 58, 36, 54, 43, 44,
#         49, 42], device='cuda:0')
# =========================================
# Prediction - perfect prediction
# tensor([42, 50, 62, 48, 56, 38, 43, 62, 41, 56, 53, 55, 43, 40, 53, 62, 55, 43,
#         36, 49, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79], device='cuda:0')
# loss: tensor(30.3025, device='cuda:0')
# --------
# Prediction - model
# tensor([79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79], device='cuda:0')
# loss: tensor(3.4614, device='cuda:0')
# --------
# Ground Truth
# tensor([42, 50, 62, 48, 56, 38, 43, 62, 41, 56, 53, 55, 43, 40, 53, 62, 55, 43,
#         36, 49, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79],
#        device='cuda:0')
# Unpadded ground truth
# tensor([42, 50, 62, 48, 56, 38, 43, 62, 41, 56, 53, 55, 43, 40, 53, 62, 55, 43,
#         36, 49], device='cuda:0')
# =========================================
# Prediction - perfect prediction
# tensor([36, 62, 38, 50, 48, 51, 47, 40, 55, 40, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79], device='cuda:0')
# loss: tensor(61.9394, device='cuda:0')
# --------
# Prediction - model
# tensor([79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79], device='cuda:0')
# loss: tensor(3.8248, device='cuda:0')
# --------
# Ground Truth
# tensor([36, 62, 38, 50, 48, 51, 47, 40, 55, 40, 79, 79, 79, 79, 79, 79, 79, 79,
#         79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79],
#        device='cuda:0')
# Unpadded ground truth
# tensor([36, 62, 38, 50, 48, 51, 47, 40, 55, 40], device='cuda:0')

# create perfect prediction and model prediction

# key parameters
time_steps = 189
n_class = 80
blank_idx = 79

# construct the perfect prediction based on ground truth
eps = 0.0001
B, curr_t = batch_y.shape
batch_y = batch_y.cpu()
m_fit = torch.cat([torch.zeros(time_steps-curr_t, blank_idx), torch.ones(time_steps-curr_t, 1), torch.zeros(time_steps-curr_t, n_class-1-blank_idx)], dim=1)
pred_perf_prob = torch.stack([(torch.cat([torch.eye(n_class)[batch_y[i]], m_fit], dim=0)*(1-eps*n_class)+eps) for i in range(len(batch_y))], dim=0).to(device) # (B, T, n_class)
pred_perf = torch.nn.functional.log_softmax(pred_perf_prob, dim=2).permute((1,0,2)) # (T, B, n_class)

# get model prediction, which predicts all blank
with torch.no_grad():
    pred_model, inputs = model.network(batch_x) # (T, B, n_class)

inputls = torch.full(size=(B,), fill_value=time_steps, dtype=torch.long).to(device)
outputls = (torch.sum(batch_y != 79, dim=1)).to(torch.long).to(device) #tensor([20, 20, 10])

Env: PyTorch 1.1, CUDA 9

A follow-up with an additional observation and minimal code for reproduction: I create a perfect prediction from the one-hot encoding of the y label, and an all-blank prediction. Both are a batch of one datum. Weirdly, the loss of the perfect prediction is higher than that of all_blank when input_length is large, but lower when input_length is small. Is this expected?

import torch
import torch.nn.functional as F

T = 189
n_class = 80
y = torch.tensor([[55, 43, 40, 62, 41, 44, 53, 40, 62, 41, 50, 53, 62, 58, 36, 54, 43, 44, 49, 42]])
output_length = torch.tensor(y.shape[1])

pred_model_idx = 79*torch.ones(T, dtype=torch.long)
pred_perf_idx = torch.cat([y[0], (n_class-1) * torch.ones(T-y.shape[1], dtype=torch.long)]) # the first idx are perfect with y, then padded with blanks
pred_model = torch.eye(n_class)[pred_model_idx].unsqueeze(1) # one-hot encoding
pred_perf = torch.eye(n_class)[pred_perf_idx].unsqueeze(1) # one-hot encoding

for input_length in [torch.tensor(y.shape[1]), torch.tensor(T)]:
    print("=============\ninput length:", input_length)
    print("perfect loss:", F.ctc_loss(F.log_softmax(pred_perf, dim=2), y, input_length, output_length, n_class-1, 'none', True))
    print("all_blank loss:", F.ctc_loss(F.log_softmax(pred_model, dim=2), y, input_length, output_length, n_class-1, 'none', True))

# OUTPUT
# =============
# input length: tensor(20)
# perfect loss: tensor([68.0656])
# all_blank loss: tensor([88.0655])
# =============
# input length: tensor(189)
# perfect loss: tensor([605.4802])
# all_blank loss: tensor([593.8168])

To be honest, neither calling it “perfect prediction” (note that the log_softmax result will assign probability mass (log) to non-target classes) nor changing the length in the way you do makes much sense to me. Could just be me, though.

Here is an observation: if your input is longer than your targets the “aligned paths” over which you take probabilities will necessarily include blanks. So if you need enough blanks, assigning a high probability to them will reduce the loss.

Best regards

Thomas
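The effect of blanks on the summed path probability can be checked with a small pure-Python version of the standard CTC forward recursion. This is an illustrative re-implementation, not PyTorch's code; the helper names ctc_nll, log_add, and one_hot_log_softmax are made up for this sketch. It feeds in the same kind of inputs as the repro above: one-hot rows passed through log_softmax.

```python
import math

def log_add(a, b):
    """log(exp(a) + exp(b)), safe for -inf."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def ctc_nll(log_probs, target, blank):
    """CTC negative log-likelihood; log_probs[t][c] is the log-probability of class c at step t."""
    ext = [blank]
    for c in target:
        ext += [c, blank]                 # blank, y1, blank, y2, ..., blank
    S = len(ext)
    alpha = [-math.inf] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]
    for t in range(1, len(log_probs)):
        new = [-math.inf] * S
        for s in range(S):
            a = alpha[s]
            if s >= 1:
                a = log_add(a, alpha[s - 1])
            # Skipping a blank is allowed only between two different labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                a = log_add(a, alpha[s - 2])
            if a != -math.inf:
                new[s] = a + log_probs[t][ext[s]]
        alpha = new
    total = alpha[S - 1]
    if S > 1:
        total = log_add(total, alpha[S - 2])
    return -total

n_class, blank = 80, 79
y = [55, 43, 40, 62, 41, 44, 53, 40, 62, 41, 50, 53, 62, 58, 36, 54, 43, 44, 49, 42]

# log_softmax of a one-hot row: the hot class gets 1 - log_z, every other class 0 - log_z.
log_z = math.log(math.e + n_class - 1)

def one_hot_log_softmax(idx):
    return [(1.0 if c == idx else 0.0) - log_z for c in range(n_class)]

results = {}
for T in (20, 189):
    perfect = [one_hot_log_softmax(c) for c in (y + [blank] * (T - len(y)))]
    all_blank = [one_hot_log_softmax(blank) for _ in range(T)]
    results[T] = (ctc_nll(perfect, y, blank), ctc_nll(all_blank, y, blank))
    print(T, results[T])
```

With T = 20 there is exactly one valid alignment, so the loss is just twenty times the per-step cost. With T = 189 the sum runs over a very large number of alignments, nearly all of which consist mostly of blanks, so a uniformly high blank probability is rewarded.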

Thanks Thomas, I really appreciate your reply! I probably didn’t explain myself well: I assume perfect_prediction is more similar to the y label than predict_blank, since perfect_prediction contains a prediction path that is exactly the same as the y label, while predict_blank can only be decoded as an empty sequence. If this is right, why would perfect_prediction ever have a higher loss than predict_blank? (The former should always have the lower loss, no?) The experiment with the changing length is less relevant; I’m just surprised that this behavior is not monotonic and depends on input_length.

BTW, do you mind explaining ‘So if you need enough blanks, assigning a high probability to them will reduce the loss’? Maybe I’ve missed some part.

The context here is: I’m training an OCR model with CTCLoss, and during training I can see the loss going down, but the model just keeps predicting all blanks (it seems other people have made similar observations). I’m not sure where the bug is, which led to the loss comparison above.
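For reference, the decoding step assumed in this discussion (take the argmax per time step, collapse consecutive repeats, drop blanks) can be sketched as follows. The function name greedy_ctc_decode is made up for illustration; the label values are taken from the third sample printed above.

```python
def greedy_ctc_decode(frame_labels, blank):
    """Best-path decoding: collapse consecutive repeats, then drop blanks."""
    decoded = []
    prev = None
    for label in frame_labels:
        if label != prev and label != blank:
            decoded.append(label)
        prev = label
    return decoded

blank = 79
y = [36, 62, 38, 50, 48, 51, 47, 40, 55, 40]
perfect_frames = y + [blank] * 179       # argmax per time step, 189 steps in total
all_blank_frames = [blank] * 189

print(greedy_ctc_decode(perfect_frames, blank))    # the ten target labels
print(greedy_ctc_decode(all_blank_frames, blank))  # []
```

An all-blank output therefore decodes to the empty sequence, while a repeated label in the target needs a blank between the two occurrences to survive the collapse step.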

Hi, have you solved the problem?

I’m having a similar problem.
It seems some data give me an extremely high loss even when the inference result is perfect. I worry whether nn.CTCLoss is even useful.

I want to know how to avoid the problem if possible.

Note that in the above example the inference result is not “perfect” because you need to pass log-softmaxed activations into the CTC loss. This means that the ideal case would be an alignment where all classes except the true target have log-probability -inf and the true target has 0.
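A short arithmetic sketch of this point, assuming 80 classes and the simplest case where the input length equals the target length (20 steps, no repeated neighbouring labels), so that there is only one alignment and the loss is the sum of the per-step costs. The eps smoothing mirrors the value used in the setup earlier in the thread; the variable names are illustrative.

```python
import math

n_class = 80
T = 20   # input length equal to target length: a single valid alignment

# Case 1: one-hot scores passed through log_softmax. The target class only
# receives e / (e + n_class - 1) of the probability mass.
p_target_softmax = math.e / (math.e + n_class - 1)
loss_softmax = -T * math.log(p_target_softmax)

# Case 2: probabilities that are already close to one-hot, converted with log().
eps = 1e-4
p_target_smooth = 1.0 - eps * (n_class - 1)
loss_smooth = -T * math.log(p_target_smooth)

print(round(p_target_softmax, 4), round(loss_softmax, 4))
print(round(p_target_smooth, 4), round(loss_smooth, 4))
```

In the first case the target class holds only about 3% of the probability at each step, which is why the so-called perfect prediction still pays a large cost per time step.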

Thanks, I got it.

It seems I got a big loss because the probability at each time step is relatively small, even though it is the largest one at that time step. I don’t know which works better in real-world use: a small loss (predicting something extremely confidently) or a relatively big loss (predicting the same thing, but less confidently).