Higher order gradients of CTCLoss

I’m using the CTCLoss in PyTorch version 1.0.0.

My vague understanding, from the source and the discussions I’ve read, is that it wraps some external C++ modules (linking to cuDNN) and implements its own backward() rather than relying on PyTorch’s autograd.

I therefore assume that higher-order gradients (e.g. Hessian-vector products, HVPs) through CTCLoss won’t work.

For example, running:

```python
# CTC Loss test
import torch

ctc_loss = torch.nn.CTCLoss()
log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()

targets = torch.randint(1, 20, (16, 30), dtype=torch.long)

input_lengths = torch.full((16,), 50, dtype=torch.long)
target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

(grads, ) = torch.autograd.grad(loss, log_probs, create_graph=True)  # ∇J
grads_summed = torch.sum(grads, (0, 1, 2))  # equivalent to ∇J⋅1
torch.autograd.grad(grads_summed, log_probs)  # attempt HVP: ∇(∇J⋅1) = H⋅1
```

results in:

RuntimeError: derivative for _ctc_loss_backward is not implemented

Is there any way around this other than to take on the daunting task of implementing backward() for ctc_loss_backward?
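One possible workaround, sketched under assumptions rather than offered as a tested recipe: write the CTC forward (alpha) recursion directly with differentiable tensor ops, so autograd builds the backward (and the backward of the backward) itself instead of going through the native `_ctc_loss_backward`. `ctc_nll_autograd` and `NEG` are hypothetical names; the sketch assumes CPU tensors, blank index 0 and padded (N, S) targets, and it loops in Python over samples and time steps, so it is only practical for small problems or for cross-checking against `torch.nn.functional.ctc_loss`.

```python
import torch
import torch.nn.functional as F

# Large finite stand-in for log(0). With a true -inf, logsumexp over an
# all -inf slice gives NaN gradients, which would poison higher-order grads.
NEG = -1e30


def ctc_nll_autograd(log_probs, targets, input_lengths, target_lengths, blank=0):
    """Hypothetical helper (not a PyTorch API): per-sample CTC negative
    log-likelihood from differentiable tensor ops only, so autograd can
    differentiate through it more than once.

    log_probs: (T, N, C); targets: (N, S_max) padded; lengths: 1-D int tensors.
    Returns a tensor of shape (N,).
    """
    losses = []
    for n in range(log_probs.shape[1]):
        T_n = int(input_lengths[n])
        labels = targets[n, : int(target_lengths[n])].tolist()
        # Extended label sequence l' = [blank, l1, blank, l2, ..., lL, blank]
        ext = [blank]
        for c in labels:
            ext += [c, blank]
        S = len(ext)
        # Jumping from s-2 to s is allowed only onto a non-blank label that
        # differs from l'_{s-2}; otherwise repeated labels would merge.
        skip = torch.tensor(
            [s >= 2 and ext[s] != blank and ext[s] != ext[s - 2] for s in range(S)]
        )
        emit = log_probs[:T_n, n, :][:, torch.tensor(ext)]  # log y^t_{l'_s}
        neg = torch.full((S,), NEG, dtype=log_probs.dtype)
        # t = 0: a path starts on the leading blank or on the first label.
        alpha = torch.where(torch.arange(S) < 2, emit[0], neg)
        for t in range(1, T_n):
            step = torch.cat([neg[:1], alpha[:-1]])  # arrive from s-1
            jump = torch.where(skip, torch.cat([neg[:2], alpha[:-2]]), neg)  # from s-2
            alpha = torch.logsumexp(torch.stack([alpha, step, jump]), dim=0) + emit[t]
        # A path ends on the last label or on the trailing blank.
        losses.append(-torch.logsumexp(alpha[-2:], dim=0))
    return torch.stack(losses)


# Usage: the same kind of Hessian-vector product as in the question, on a tiny problem.
torch.manual_seed(0)
logits = torch.randn(6, 2, 5, dtype=torch.float64, requires_grad=True)
targets = torch.tensor([[1, 2, 2], [3, 4, 0]])
input_lengths = torch.tensor([6, 5])
target_lengths = torch.tensor([3, 2])

nll = ctc_nll_autograd(logits.log_softmax(2), targets, input_lengths, target_lengths)
(g,) = torch.autograd.grad(nll.sum(), logits, create_graph=True)  # ∇J
(hvp,) = torch.autograd.grad(g.sum(), logits)  # ∇(∇J⋅1) = H⋅1
```

The alpha recursion is the textbook log-space forward pass; the only design choice worth flagging is the finite `NEG` sentinel, which trades exactness for NaN-free second derivatives on unreachable states.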


Indeed. I’d recommend starting from the original CTC paper (Graves et al., 2006): you could differentiate equation 15 and the derivative below it to obtain the second derivative. I’ve tried to comment the source so that it is easy to follow with the paper in mind (but remember that alpha and beta are computed in log space).
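As a rough guide to what that differentiation involves, here is a sketch in the paper's notation. It assumes, as in the paper, that both α_t(s) and β_t(s) include the emission y^t_{l'_s} (so it appears twice in their product), with lab(l, k) = {s : l'_s = k} and B the collapsing map. The last two lines follow from writing p(l|x) as a sum over paths π of products of per-step outputs rather than being quoted from the paper: each path uses exactly one symbol per time step, so p is linear in each y^t separately and second derivatives within one time step vanish.

```latex
\begin{aligned}
p(\mathbf{l}\mid\mathbf{x}) &= \sum_{s=1}^{|\mathbf{l}'|} \frac{\alpha_t(s)\,\beta_t(s)}{y^t_{l'_s}} \quad \text{for any } t,\\
\frac{\partial p(\mathbf{l}\mid\mathbf{x})}{\partial y^t_k} &= \frac{1}{(y^t_k)^2} \sum_{s \in \mathrm{lab}(\mathbf{l},k)} \alpha_t(s)\,\beta_t(s),\\
\frac{\partial^2 p(\mathbf{l}\mid\mathbf{x})}{\partial y^t_k\,\partial y^{t'}_{k'}} &=
\begin{cases}
0, & t = t',\\
\displaystyle\sum_{\pi \in \mathcal{B}^{-1}(\mathbf{l}),\ \pi_t = k,\ \pi_{t'} = k'} \ \prod_{\tau \neq t,\, t'} y^{\tau}_{\pi_\tau}, & t \neq t',
\end{cases}\\
\nabla^2\bigl(-\ln p\bigr) &= -\frac{\nabla^2 p}{p} + \frac{\nabla p\,\nabla p^{\top}}{p^2}.
\end{aligned}
```

The cross-time term needs sums over paths pinned at two time steps, so a direct implementation would need something beyond a single alpha/beta pass.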
I’m not sure whether it will work well numerically: even for the first derivative of ctc_loss, numerical precision can be touchy.
I’d be interested to see your results.

Best regards



@mebrunet Any progress on this? Would love to know if you made any headway on implementing backward for ctc_loss_backward.

I haven’t made the time yet to try… Our work took a turn, and I haven’t needed it since. Will post here if ever I take a crack at it.

Hi @tom! I’m trying to do a double backward for CTC loss. Is there a way to approximate the CTC gradient with finite differences? And is there a way to compute the CTC backward in double precision so that there are no precision problems?

I’m not aware of a double backward for CTC loss being available.
By their very nature, finite differences compute approximate directional derivatives (Jacobian-vector products, if you like, which is what you would get with forward mode) rather than gradients (aka vector-Jacobian products). As such, approximating gradients with finite differences is generally compute-intensive, because you would need as many function evaluations as you have inputs.
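Following that reasoning, one cheap option for the original HVP question (a swapped-in technique, not something confirmed in this thread) is to take a finite-difference directional derivative of the first-order gradient, which the native op does provide: (∇J(x + εv) − ∇J(x − εv)) / (2ε) ≈ H v, at the cost of two gradient evaluations per vector v rather than one per input. The sketch assumes float64 inputs, applies log_softmax inside the loss so the perturbations act on unconstrained logits, and uses an arbitrary ε = 1e-5; `loss_fn`, `grad_fn` and `fd_hvp` are hypothetical helper names.

```python
import torch
import torch.nn.functional as F

# Small fixed CTC problem in float64 (shapes and targets are arbitrary choices).
T, N, C = 6, 2, 5
targets = torch.tensor([[1, 2, 2], [3, 4, 0]])
input_lengths = torch.tensor([6, 5])
target_lengths = torch.tensor([3, 2])


def loss_fn(logits):
    # log_softmax inside, so the function is defined on unnormalised logits
    return F.ctc_loss(logits.log_softmax(2), targets, input_lengths, target_lengths,
                      reduction="sum")


def grad_fn(logits):
    # First-order gradient from the native CTC backward
    x = logits.detach().requires_grad_()
    (g,) = torch.autograd.grad(loss_fn(x), x)
    return g


def fd_hvp(logits, v, eps=1e-5):
    # Central difference of the gradient along v approximates H @ v
    return (grad_fn(logits + eps * v) - grad_fn(logits - eps * v)) / (2 * eps)


torch.manual_seed(0)
logits = torch.randn(T, N, C, dtype=torch.float64)
v1 = torch.randn_like(logits)
v2 = torch.randn_like(logits)
hv1 = fd_hvp(logits, v1)
hv2 = fd_hvp(logits, v2)

# The Hessian is symmetric, so v2 . (H v1) should match v1 . (H v2)
print((v2 * hv1).sum().item(), (v1 * hv2).sum().item())
```

Because the Hessian is symmetric, the two printed numbers should agree closely, which gives a quick sanity check on the choice of ε.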
You can use CTC loss with doubles just by passing double tensors; enabling this was one of my goals in having an integrated open-source implementation.
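A minimal sketch of that, assuming float64 CPU tensors and an arbitrary small problem: passing float64 logits is enough for the loss to be computed in double precision, and `torch.autograd.gradcheck` can then compare the analytic first-order gradient with finite differences. log_softmax is applied inside `loss_fn` (a hypothetical helper) so that the finite-difference perturbations act on unconstrained logits rather than on log-probabilities that are supposed to stay normalised.

```python
import torch
import torch.nn.functional as F

# Small, fixed CTC problem (arbitrary shapes and targets chosen for illustration).
targets = torch.tensor([[1, 2], [3, 1]])
input_lengths = torch.tensor([5, 5])
target_lengths = torch.tensor([2, 2])


def loss_fn(logits):
    # log_softmax inside, so gradcheck perturbs unnormalised float64 logits
    return F.ctc_loss(logits.log_softmax(2), targets, input_lengths, target_lengths)


torch.manual_seed(0)
logits = torch.randn(5, 2, 4, dtype=torch.float64, requires_grad=True)

loss = loss_fn(logits)
print(loss.dtype)  # torch.float64

# Compare the analytic first-order gradient with float64 finite differences.
ok = torch.autograd.gradcheck(loss_fn, (logits,))
print("gradcheck passed:", ok)
```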

Best regards