CTC Loss Unexpected Behaviour

I have used the following code to test the behaviour of CTC loss.

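```python
import torch
import torch.nn.functional as fn
from torch.nn import CTCLoss

def get_char_maps(vocab):
    # Map each character to an integer index and back
    char_to_index, index_to_char = {}, {}
    for i, c in enumerate(vocab):
        char_to_index[c] = i
        index_to_char[i] = c
    return (char_to_index, index_to_char, len(vocab))

loss_function = CTCLoss()  # defaults: blank=0, reduction='mean'
char_to_index, index_to_char, vocab_size = get_char_maps(['~', 'a', 'b', 'c', 'd', 'e', 'f', 'g', ' '])

empty_char = '~'  # index 0, i.e. the CTC blank
label = 'abc'
seq_len = 20
predict = 'abc'

# Insert a blank between repeated characters so they are not merged when decoding
pred_out = predict[0]
for i in range(1, len(predict)):
    if predict[i] == predict[i - 1]:
        pred_out += empty_char + predict[i]
    else:
        pred_out += predict[i]
predict = pred_out
label_len = len(label)

# Centre the prediction in a sequence of seq_len frames padded with blanks
left_side = (seq_len - len(predict)) // 2
right_side = seq_len - left_side - len(predict)
seq_predict = [empty_char] * left_side + list(predict) + [empty_char] * right_side

# One-hot logits of shape (T, N=1, C), the layout CTCLoss expects
out = []
for char in seq_predict:
    temp = [0] * vocab_size
    temp[char_to_index[char]] = 1
    out.append([temp])
out = torch.FloatTensor(out)

scores = fn.log_softmax(out, dim=2)  # CTCLoss expects log probabilities
out_size = torch.tensor([seq_len] * 1, dtype=torch.int)  # input lengths
y_size = torch.tensor([len(label)], dtype=torch.int)     # target lengths
y = [char_to_index[c] for c in label]
y_var = torch.tensor(y, dtype=torch.int)
l = loss_function(scores, y_var, out_size, y_size)
```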

With this code, when the target is abc and the prediction is also abc, it produces a loss of 7.770. However, with the same target, when the prediction is g, it produces a loss of 7.654, which is lower than for the perfect prediction. This behaviour seems counterintuitive.
Is this the expected behaviour of CTC loss, or does this test code have an error?


This is expected behaviour. I would venture that your interpretation of the test code, and of what a “prediction” means in this context, is not quite right.

Note that the inputs to CTC loss are always log probabilities and that CTC always sums over all possible alignments with the target. In your example, after log_softmax, the “predicted” character gets about 25% of the probability mass at each frame and every other character a little over 9%.
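Where the 25% and 9% come from: with one-hot logits (1 for the chosen symbol, 0 for the other 8 of the 9 symbols), the softmax of each frame is e/(e + 8) for the chosen symbol and 1/(e + 8) for each of the others. A minimal check in plain Python, assuming exactly the question's 9-symbol vocabulary:

```python
import math

# One-hot "logits" from the question: 1 for the predicted character,
# 0 for the other 8 entries of the 9-symbol vocabulary (blank included).
vocab_size = 9
hot, cold = 1.0, 0.0

# Softmax over a single frame: exp(logit) / sum of exps
denom = math.exp(hot) + (vocab_size - 1) * math.exp(cold)
p_hot = math.exp(hot) / denom    # e / (e + 8)
p_cold = math.exp(cold) / denom  # 1 / (e + 8)

print(f"predicted character: {p_hot:.3f}, each other character: {p_cold:.3f}")
```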
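Now, taking the argmax at each frame (the “greedy” decoding, in beam search terminology) gives you back the prediction, but the loss is the sum over all possible alignments of the target. In the “abc” case, the “empty” (blank) character gets only about 9% instead of 25% at the frames that favour a, b and c, so every alignment that keeps those frames blank becomes less likely, and this reduces the total probability of the target.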
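To make “sums over all possible alignments” concrete, here is a brute-force sketch in plain Python (no PyTorch) for a tiny, hypothetical example: 3 frames, symbols {~, a, b}, target “ab”. It enumerates every frame-level path, applies the CTC collapse rule (merge repeats, then drop blanks) and adds up the probabilities of the paths that collapse to the target. `collapse` and `brute_force_ctc_prob` are illustrative helpers, not PyTorch API.

```python
import itertools
import math

BLANK = "~"

def collapse(path):
    """CTC collapse rule: merge consecutive repeats, then drop blanks."""
    merged = [c for i, c in enumerate(path) if i == 0 or c != path[i - 1]]
    return "".join(c for c in merged if c != BLANK)

def brute_force_ctc_prob(frame_probs, symbols, target):
    """Sum the probabilities of all frame-level paths that collapse to `target`.

    frame_probs[t][s] is the probability of symbol s at frame t. The loop is
    exponential in the number of frames, so this only works for tiny examples.
    """
    total, alignments = 0.0, []
    for path in itertools.product(symbols, repeat=len(frame_probs)):
        if collapse(path) == target:
            alignments.append("".join(path))
            total += math.prod(frame_probs[t][s] for t, s in enumerate(path))
    return total, alignments

symbols = [BLANK, "a", "b"]
uniform = [{s: 1 / 3 for s in symbols} for _ in range(3)]  # 3 frames, uniform over 3 symbols
prob, alignments = brute_force_ctc_prob(uniform, symbols, "ab")
print(alignments)             # the 5 paths that collapse to "ab"
print(prob, -math.log(prob))  # summed probability (5/27) and the corresponding CTC loss
```

No single path is “the” prediction here: the loss depends on how much probability every valid alignment gets.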
If you had “sharper” predictions (replace the 0s in out with strongly negative values, say -10), you would see the loss for the “wrong” prediction go up and the loss for the “correct” prediction go down.
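A plain-Python sketch of that experiment, assuming the question's setup (9-symbol vocabulary, blank at index 0, 20 frames, prediction centred in blank padding). It computes the CTC loss with the standard forward (alpha) recursion, which sums over all alignments by dynamic programming instead of enumerating them. `frame_log_probs`, `ctc_nll` and `logsumexp` are hypothetical helpers written for this illustration, not PyTorch functions; `hot`/`cold` are the logits for the chosen and the other symbols, so `cold=-10.0` is the sharpening suggested in the previous paragraph.

```python
import math

VOCAB = ['~', 'a', 'b', 'c', 'd', 'e', 'f', 'g', ' ']  # '~' at index 0 is the blank
IDX = {c: i for i, c in enumerate(VOCAB)}
BLANK = IDX['~']

def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))

def frame_log_probs(predict, seq_len=20, hot=1.0, cold=0.0):
    """Rebuild the question's input: blank-separated repeats, centred in a
    blank-padded sequence, logit `hot` for the chosen symbol and `cold` for
    all others, then log-softmax per frame."""
    out = predict[0]
    for prev, cur in zip(predict, predict[1:]):
        out += ('~' + cur) if cur == prev else cur
    left = (seq_len - len(out)) // 2
    frames = '~' * left + out + '~' * (seq_len - left - len(out))
    rows = []
    for ch in frames:
        logits = [cold] * len(VOCAB)
        logits[IDX[ch]] = hot
        norm = logsumexp(logits)
        rows.append([x - norm for x in logits])
    return rows

def ctc_nll(log_probs, target):
    """CTC negative log-likelihood via the forward (alpha) recursion.
    Assumes a non-empty target."""
    ext = [BLANK]
    for c in target:
        ext += [IDX[c], BLANK]            # blank-extended target: ~ a ~ b ~ c ~
    alpha = [-math.inf] * len(ext)
    alpha[0] = log_probs[0][ext[0]]       # start with a blank ...
    alpha[1] = log_probs[0][ext[1]]       # ... or with the first label
    for frame in log_probs[1:]:
        new = []
        for s, sym in enumerate(ext):
            prev = [alpha[s]]                              # stay in the same state
            if s >= 1:
                prev.append(alpha[s - 1])                  # advance by one state
            if s >= 2 and sym != BLANK and sym != ext[s - 2]:
                prev.append(alpha[s - 2])                  # skip a blank between different labels
            new.append(logsumexp(prev) + frame[sym])
        alpha = new
    return -logsumexp([alpha[-1], alpha[-2]])  # end on the last label or a trailing blank

label = 'abc'
for cold in (0.0, -10.0):
    # Dividing by len(label) mimics reduction='mean' for a batch of one
    loss_abc = ctc_nll(frame_log_probs('abc', cold=cold), label) / len(label)
    loss_g = ctc_nll(frame_log_probs('g', cold=cold), label) / len(label)
    print(f"cold={cold}: prediction 'abc' -> {loss_abc:.3f}, prediction 'g' -> {loss_g:.3f}")
```

The sketch is meant to show the trend (the ordering of the two losses flips once the predictions are sharp), not to serve as a replacement for `torch.nn.CTCLoss`.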

Finally, as a quick smell test: CTC loss is a negative log probability, so the “perfect score” is 0 (the negative log probability of the target sequence when that probability is 1), and a loss around 7 is much more than you would expect from a strongly overfit network. That said, one generally has to take the sequence length into account: the longer the sequence, the harder it is to get close to a loss of 0 without very spiked predictions.
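A rough numerical illustration of both points. It assumes the question's `CTCLoss()` used the default `reduction='mean'`, which divides the summed negative log-likelihood by the target length (3 for “abc”), so the reported loss can be turned back into a probability. The second part uses a hypothetical helper `path_nll` for the negative log probability of a single alignment; when one alignment dominates the sum (peaked predictions), the CTC loss is close to that value, which grows linearly with the number of frames.

```python
import math

# Assumption: reduction='mean' (the CTCLoss default) divides the summed negative
# log-likelihood by the target length, which is 3 for "abc".
reported_loss, target_len = 7.770, 3
target_prob = math.exp(-reported_loss * target_len)
print(f"probability assigned to 'abc': {target_prob:.1e}")  # far from 1, so far from loss 0

def path_nll(p_frame, num_frames):
    """Negative log probability of one alignment whose every frame has probability p_frame."""
    return -num_frames * math.log(p_frame)

# Even fairly confident per-frame predictions add up over long sequences
for p_frame in (0.254, 0.99, 0.9999):
    print(p_frame, [round(path_nll(p_frame, T), 3) for T in (20, 100, 500)])
```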

Distill.pub has a good exposition of CTC (“Sequence Modeling With CTC”) that illustrates the role alignments play in the CTC loss.

Best regards

