Label smoothing with CTCLoss

Does anybody know how to implement label smoothing (LS) with CTCLoss? I found a lot of articles about CrossEntropyLoss with Label smoothing, but nothing about CTCLoss.
I’ve found a paper that mentions LS with CTCLoss, using an equivalent approach based on KLDiv, but I have no idea how to implement it in PyTorch. TIA!

I implemented CTCLoss with label smoothing like this:

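import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss


class SmoothCTCLoss(_Loss):

    def __init__(self, num_classes, blank=0, weight=0.01):
        # Module.__init__ must run before sub-modules are assigned.
        super().__init__(reduction='mean')
        self.weight = weight            # label smoothing weight
        self.num_classes = num_classes  # size of the class dimension C

        self.ctc = nn.CTCLoss(reduction='mean', blank=blank, zero_infinity=True)
        self.kldiv = nn.KLDivLoss(reduction='batchmean')

    def forward(self, log_probs, targets, input_lengths, target_lengths):
        # log_probs: (T, N, C) log-softmax outputs, as nn.CTCLoss expects.
        ctc_loss = self.ctc(log_probs, targets, input_lengths, target_lengths)

        # Batch-first view (N, T, C), so 'batchmean' divides by N.
        kl_inp = log_probs.transpose(0, 1)
        # Uniform target distribution with the same shape as kl_inp.
        kl_tar = torch.full_like(kl_inp, 1. / self.num_classes)
        kldiv_loss = self.kldiv(kl_inp, kl_tar)

        # Weighted sum of the CTC term and the smoothing term.
        loss = (1. - self.weight) * ctc_loss + self.weight * kldiv_loss
        return loss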

Can anybody confirm if this is correct or not? I also referred to this:

  1. Your ctc_loss looks straightforward, so there shouldn’t be any issues.
  2. In kl_tar: why are you using kl_inp (which comes from the input log_probs) to create kl_tar? Shouldn’t it depend on the targets input instead?

Thanks for your interest, @Abhilash_Srivastava. I only borrow the shape of kl_inp to create the kl_tar tensor, which is filled with the constant value 1. / self.num_classes.
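To make that concrete, here is a minimal sketch of what the uniform target does. The sizes N, T, C and the random input are made up for illustration; the only claim is that the built-in KL divergence against a uniform target matches the hand-written sum of u * (log u - log p) divided by the batch size.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, T, C = 2, 3, 4  # hypothetical batch size, time steps, classes

# Stand-in for the transposed model output: log-probabilities, (N, T, C).
kl_inp = torch.randn(N, T, C).log_softmax(dim=-1)

# Only shape, dtype and device are borrowed from kl_inp; every entry is 1/C.
kl_tar = torch.full_like(kl_inp, 1.0 / C)

# Built-in KL divergence; 'batchmean' sums everything and divides by N.
kl_builtin = F.kl_div(kl_inp, kl_tar, reduction='batchmean')

# The same quantity written out by hand.
kl_manual = (kl_tar * (kl_tar.log() - kl_inp)).sum() / N

print(kl_builtin.item(), kl_manual.item())
```

Note that the uniform target does not depend on the label sequence at all, so this term only pulls every frame's prediction toward a flat distribution.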

Ohkk, got it!
Are you seeing any issues now?

No. I just want to make sure my implementation is correct. It seems to work as I expected, but I’m still testing it with several choices of label smoothing weight.

Cool! Looks fine to me.
If label smoothing is bothering you, another way to test it is to set the label confidence to 1, i.e. simply use a one-hot representation with the KL-divergence loss. In this case, your loss values should exactly match the cross-entropy loss values.
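A quick sketch of that sanity check for a plain classification setting (not CTC). The logits and labels are random placeholders; with a one-hot target and 'batchmean' reduction, the KL-divergence should agree with the mean cross-entropy up to floating-point error.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 5                      # hypothetical number of classes
logits = torch.randn(4, num_classes) # hypothetical batch of 4 samples
labels = torch.tensor([0, 2, 1, 4])

log_probs = F.log_softmax(logits, dim=-1)
one_hot = F.one_hot(labels, num_classes=num_classes).float()

# KL(one_hot || pred): the target entropy term is zero for a one-hot target,
# so only -log p[label] remains for each sample.
kl = F.kl_div(log_probs, one_hot, reduction='batchmean')

# Standard cross-entropy, averaged over the batch.
ce = F.cross_entropy(logits, labels)

print(kl.item(), ce.item())
```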

It’s good to know! Thank you for your comment!

Just curious, why do you transpose the logits as follows?

    kl_inp = log_probs.transpose(0, 1)

Should the shape of kl_inp be something like (BatchSize, TimeStep, Channels), or something else?
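One likely reason, offered as a reading of the code rather than the author's stated intent: nn.CTCLoss takes log_probs as (TimeStep, BatchSize, Channels), while KLDivLoss with reduction='batchmean' divides the summed loss by the size of the first dimension. Transposing to (BatchSize, TimeStep, Channels) makes that divisor the batch size. The sketch below uses made-up sizes to show the difference.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, N, C = 6, 2, 5  # hypothetical time steps, batch size, classes

# Layout expected by nn.CTCLoss: (T, N, C).
log_probs = torch.randn(T, N, C).log_softmax(dim=-1)
kldiv = nn.KLDivLoss(reduction='batchmean')

# Batch-first layout: 'batchmean' divides the summed loss by N.
kl_inp = log_probs.transpose(0, 1)
loss_batch_first = kldiv(kl_inp, torch.full_like(kl_inp, 1.0 / C))

# Without the transpose the same sum is divided by T instead.
loss_time_first = kldiv(log_probs, torch.full_like(log_probs, 1.0 / C))

print(tuple(kl_inp.shape))  # (N, T, C)
print(loss_batch_first.item(), loss_time_first.item())
```

The two values differ only by the factor T / N, so skipping the transpose would rescale the smoothing term relative to the CTC term rather than change its direction.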