Manually call F.ctc_loss backward method

Sometimes one needs to call a gradient function manually, because the computed quantity is useful in itself. For example, there is a paper that applies reweighting to the CTC loss by interpreting it as cross-entropy against a certain distribution (and it happens that CTC's gradient computes that distribution as an intermediate step).
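
If I read the paper right, the identity being used is (my notation, not the paper's): writing z for the logits,

d CTC(z) / d z = softmax(z) - P

where P is the per-frame posterior over symbols, marginalized over all valid alignments, so P can be recovered as softmax(z) - d CTC(z) / d z.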

Is there a way to compute/access the CTC loss gradient without resorting to backward hooks? (Currently it's hidden from manual calls, if I understand correctly.)

Can I do (efficiently) something like:

import torch
import torch.nn.functional as F

# eq. 5
def ctc_alignment_targets(logits, *args, dim = 1):
    logits = logits.detach().requires_grad_()  # fresh leaf, so the grad does not flow back into the model
    log_probs = F.log_softmax(logits, dim = dim)
    ctc_grad, = torch.autograd.grad(F.ctc_loss(log_probs, *args), (logits,))
    return log_probs.exp() - ctc_grad

# eq. 13, but I have not checked this equivalence in practice
# alignment_targets = ctc_alignment_targets(logits, ...)
# ctc_loss_via_cross_entropy = (alignment_targets * F.log_softmax(logits, dim = 1)).mean()

I can imagine some other ways to accomplish this (especially if the modifications need not be differentiated through), but is there a clean way to manually compute the gradient of some known function?

There isn't a way to call CTC's backward independently of its forward (and that holds for other ops in general, too).
This means you cannot manipulate the forward's inputs before they are fed to the gradient computation.
You can wrap the call to CTC loss in a custom autograd function (detach the inputs, re-enable grad, store the graph-carrying result as an intermediate, and return a detached copy), and then call backward on that intermediate inside the backward of the custom function.
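
Roughly like this, as a sketch (untested, names made up; storing the loss on ctx keeps its graph alive until the backward runs):

import torch
import torch.nn.functional as F

class CTCLossWrapper(torch.autograd.Function):
    @staticmethod
    def forward(ctx, log_probs, targets, input_lengths, target_lengths, blank):
        with torch.enable_grad():
            # fresh leaf, so the inner graph is independent of the caller's graph
            log_probs_ = log_probs.detach().requires_grad_()
            loss = F.ctc_loss(log_probs_, targets, input_lengths, target_lengths,
                              blank = blank, reduction = 'sum')
        # keep the intermediate result (and its graph) around for the backward
        ctx.log_probs_ = log_probs_
        ctx.loss = loss
        return loss.detach()

    @staticmethod
    def backward(ctx, grad_output):
        ctc_grad, = torch.autograd.grad(ctx.loss, (ctx.log_probs_,))
        # this is where ctc_grad could be inspected or modified before being passed on
        return grad_output * ctc_grad, None, None, None, None

One would then call CTCLossWrapper.apply(log_probs, targets, input_lengths, target_lengths, blank) in place of F.ctc_loss.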

Personally, I’d think that it’d be cool to expose all backwards in a standard way (and also get rid of sticking them in the Functions.cpp template), but that never gained traction with anyone else.

Best regards

Thomas


I’ve tried the following:

def ctc_alignment_targets(logits, targets, input_lengths, target_lengths, blank, dim = 1):
    with torch.enable_grad():
        # detached leaf: the gradient is taken w.r.t. this copy only
        logits = logits.detach().requires_grad_()
        log_probs = F.log_softmax(logits, dim = dim)
        # F.ctc_loss expects (T, N, C); logits here are (N, C, T)
        ctc_loss = F.ctc_loss(log_probs.permute(2, 0, 1), targets, input_lengths, target_lengths, blank, reduction = 'sum')
        ctc_grad, = torch.autograd.grad(ctc_loss, (logits,))
        # zero out frames beyond each sequence's input length, shape (N, 1, T)
        temporal_mask = (torch.arange(log_probs.shape[-1], device = input_lengths.device, dtype = input_lengths.dtype).unsqueeze(0) < input_lengths.unsqueeze(1))[:, None, :]
        alignment_targets = log_probs.exp() * temporal_mask - ctc_grad
        return log_probs, ctc_loss, alignment_targets

log_probs, ctc_loss, alignment_targets = ctc_alignment_targets(logits, ...)
ctc_loss_via_mle = (-alignment_targets * log_probs).sum()
print(ctc_loss, ctc_loss_via_mle)
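
For concreteness, the function assumes logits of shape (N, C, T) with the classes at dim = 1; a made-up toy input to run the check on (reusing the imports from the first snippet) would be:

N, C, T, L = 4, 10, 50, 20                                 # made-up sizes
logits = torch.randn(N, C, T)
targets = torch.randint(1, C, (N, L), dtype = torch.long)  # random labels, blank = 0 excluded
input_lengths = torch.full((N,), T, dtype = torch.long)
target_lengths = torch.full((N,), L, dtype = torch.long)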

Something is not quite right, the two quantities don't match. Maybe I did something wrong, but the autograd.grad approach worked (even without a custom autograd function, though I'm not sure whether it's the same speed as the custom autograd function approach you suggested).

@tom If we figure this out, it becomes possible to do label smoothing / class weighting in CTC - rather cool 🙂 The gradient of NLL w.r.t. softmax logits indeed seems to be what the paper claims, but whether the same holds for the CTC gradient w.r.t. the activations still needs checking. If you give it a look, let me know 🙂

So I didn't look at this in great detail, but from a cursory look it seems that you would need to change the formulae inside the reductions, which would mean writing custom kernels.
That said, I think you can mostly copy-paste the native CTC code into your own module (in part because I never got around to switching the pointwise part to TensorIterator, which would make it more efficient).

Best regards

Thomas

I was thinking of a trick. If the existing CTC gradient (w.r.t. the logits, if the paper is correct) is of the form log_probs.exp() - somequantity, then we can retrieve somequantity by undoing the subtraction and use it as the target distribution in an NLL loss. That way it should be possible to sidestep modifying the original CTC implementation. But something is not quite right: my code snippet does not produce the same loss value via NLL.
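
Spelled out (assuming that form is right): with somequantity held constant, the gradient of -(somequantity * log_softmax(logits)).sum() w.r.t. the logits is softmax(logits) * somequantity.sum(dim = 1, keepdim = True) - somequantity. Since somequantity sums to 1 over the classes on the valid frames (and the masked version is 0 on the padded frames), this reduces to softmax(logits) - somequantity, i.e. exactly the CTC gradient. So the gradients agree by construction; whether the loss values agree too is what eq. 13 claims.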

I did a test, and the snippet above actually works! modulo this bug: https://github.com/pytorch/pytorch/issues/31557
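
For reference, a minimal version of such a check could look like this (a sketch reusing the toy tensors above; the full test is in the gist below):

# direct CTC gradient w.r.t. the logits
logits1 = logits.detach().clone().requires_grad_()
F.ctc_loss(F.log_softmax(logits1, dim = 1).permute(2, 0, 1), targets, input_lengths, target_lengths, blank = 0, reduction = 'sum').backward()

# gradient of the cross-entropy against the recovered alignment targets (held constant via detach)
logits2 = logits.detach().clone().requires_grad_()
_, _, alignment_targets = ctc_alignment_targets(logits2, targets, input_lengths, target_lengths, blank = 0)
(-alignment_targets.detach() * F.log_softmax(logits2, dim = 1)).sum().backward()

print((logits1.grad - logits2.grad).abs().max())  # expected to be close to zero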

This is cool and opens a way towards easy experimentation with CTC modifications! It would be cool if ctc_loss had a switch so that we don't have to do log_probs.exp() (also, I'm not sure whether it's better to do F.log_softmax(...).exp() or to compute F.softmax(...) again).

My test and impl: https://gist.github.com/vadimkantorov/73e1915178f444b64f9ef01a1e96c1e4

Also, what would be your thoughts on factoring the CTC Viterbi path algorithm out of the existing CTC implementation? Should that be doable, or would it be easier to recode it from scratch?
