CTC loss comparison

Hi

I’m confused by the CTC losses available right now. I think the most popular binding (right now) is Sean Naren’s warp-ctc binding. But as Sean recommends here, we should migrate to this binding instead. Meanwhile, the native CTC loss introduced in PyTorch 1.0 is another option worth considering.

Is there any fair comparison between these bindings and the native PyTorch 1.0 CTC loss? Which one is faster? What are the most notable differences between the warp-ctc and PyTorch 1.0 CTC implementations? What are the pros and cons of using PyTorch’s CTC loss instead of warp-ctc?
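For the “which one is faster?” part, a rough timing harness is a reasonable starting point. The sketch below is an assumption-laden illustration, not a benchmark from this thread: it uses random data with arbitrary sizes, times one forward plus backward pass per call, and only wires up the native F.ctc_loss. The helper names (make_batch, time_ctc, native_ctc) are hypothetical. A warp-ctc binding could be compared by wrapping it in a function with the same four arguments, but that wrapper is not shown because the call signature differs between bindings, and many warp-ctc bindings expect unnormalized activations rather than log-probabilities (check the binding you use). Note that time per step says nothing about how many steps training needs to converge.

```python
import time

import torch
import torch.nn.functional as F


def make_batch(T=50, N=4, C=20, S=10, device="cpu", seed=0):
    """Build a random CTC batch (hypothetical helper; sizes are arbitrary)."""
    g = torch.Generator().manual_seed(seed)
    # Log-probabilities of shape (T, N, C), the layout F.ctc_loss expects.
    logits = torch.randn(T, N, C, generator=g)
    log_probs = logits.log_softmax(2).to(device).requires_grad_()
    # Padded (N, S) targets drawn from classes 1..C-1; class 0 is the blank.
    targets = torch.randint(1, C, (N, S), generator=g, dtype=torch.long).to(device)
    input_lengths = torch.full((N,), T, dtype=torch.long)
    target_lengths = torch.full((N,), S, dtype=torch.long)
    return log_probs, targets, input_lengths, target_lengths


def time_ctc(loss_fn, batch, repeats=10):
    """Average seconds per forward+backward call of loss_fn on batch."""
    log_probs, targets, input_lengths, target_lengths = batch
    # Warm-up call so one-time setup costs are not timed.
    loss_fn(log_probs, targets, input_lengths, target_lengths).backward()
    if log_probs.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        log_probs.grad = None  # drop the gradient from the previous call
        loss = loss_fn(log_probs, targets, input_lengths, target_lengths)
        loss.backward()
    if log_probs.is_cuda:
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
    return (time.perf_counter() - start) / repeats


def native_ctc(log_probs, targets, input_lengths, target_lengths):
    # PyTorch's built-in CTC loss; reduction="sum" so that every
    # implementation under test computes the same quantity.
    return F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                      blank=0, reduction="sum")


batch = make_batch()  # pass device="cuda" to time on a GPU
ms = time_ctc(native_ctc, batch) * 1e3
print(f"native ctc_loss: {ms:.2f} ms per forward+backward")
```

Timing both implementations on identical batches, and with identical reductions, is what makes such a comparison fair.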

@SeanNaren @tom

PyTorch 1.0 CTC loss is faster.
You’re pretty safe if you use the cuDNN loss (you can use print(evaluated_loss.grad_fn) to check whether it mentions cudnn or native), but there appears to be a bug I haven’t quite hunted down yet when you use the native implementation with inputs of varying length.
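A minimal sketch of that grad_fn check, assuming a recent PyTorch and CPU tensors (so it should report the native kernel). One caveat: with reduction="mean" or "sum", loss.grad_fn is the reduction node rather than the CTC node itself, so the hypothetical helper ctc_backend walks grad_fn.next_functions back until it finds a node whose class name mentions CTC. Exact node names vary between PyTorch versions. Per the ctc_loss documentation, cuDNN is only used under specific conditions (among them CUDA inputs, blank=0, concatenated int32 targets, and all input_lengths equal to T), so batches with varying input lengths fall back to the native implementation.

```python
import torch
import torch.nn.functional as F


def ctc_backend(loss):
    """Return "cudnn" or "native" depending on which CTC kernel produced loss."""
    # With reduction="mean" or "sum", loss.grad_fn is the reduction node, so
    # walk back through the autograd graph to the node that mentions CTC.
    stack, seen = [loss.grad_fn], set()
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        name = type(fn).__name__.lower()
        if "ctc" in name:
            return "cudnn" if "cudnn" in name else "native"
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return None  # no CTC node found in the graph


T, N, C, S = 30, 2, 10, 5
log_probs = torch.randn(T, N, C).log_softmax(2).requires_grad_()
targets = torch.randint(1, C, (N * S,), dtype=torch.long)  # concatenated 1-D targets
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), S, dtype=torch.long)

loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0)
print(loss.grad_fn)       # the reduction node, not the CTC node itself
print(ctc_backend(loss))  # CPU tensors, so this should report "native"
```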
I don’t recommend using my WarpCTC bindings. Sean also has PyTorch 1.0 bindings (in a branch, I think; @jinserk contributed to them).
Note that PyTorch’s CTC loss isn’t a drop-in replacement for WarpCTC. There are two differences: infinite losses are not clipped and produce NaN gradients, and PyTorch uses a different scaling when you use the ‘mean’ reduction option (each sequence’s loss is divided by its target length before averaging over the batch).
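To make both differences concrete, here is a minimal sketch on random CPU tensors (sizes and lengths are arbitrary). It checks that ‘mean’ equals the per-sequence losses divided by their target lengths and then averaged, while ‘sum’ is the plain batch sum. It also shows an impossible alignment (an input shorter than its target) giving an infinite loss, and the zero_infinity flag zeroing it. That flag is an assumption about your version: it appeared in PyTorch releases after 1.0. If your warp-ctc binding returns summed, unnormalized losses (check the binding you use), reduction="sum" is the closer match, and the two will produce gradients of different magnitude under ‘mean’.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, N, C = 40, 3, 8
log_probs = torch.randn(T, N, C).log_softmax(2)
target_lengths = torch.tensor([5, 10, 15])
targets = torch.randint(1, C, (int(target_lengths.sum()),))  # concatenated targets
input_lengths = torch.full((N,), T, dtype=torch.long)

args = (log_probs, targets, input_lengths, target_lengths)
per_seq = F.ctc_loss(*args, reduction="none")
mean_loss = F.ctc_loss(*args, reduction="mean")
sum_loss = F.ctc_loss(*args, reduction="sum")

# "mean": divide each sequence's loss by its target length, then average.
assert torch.allclose(mean_loss, (per_seq / target_lengths).mean())
# "sum": the unnormalized total over the batch.
assert torch.allclose(sum_loss, per_seq.sum())

# An input of 4 frames cannot emit a 5-label target, so its loss is infinite.
short = torch.tensor([4, T, T])
inf_loss = F.ctc_loss(log_probs, targets, short, target_lengths, reduction="none")
zeroed = F.ctc_loss(log_probs, targets, short, target_lengths,
                    reduction="none", zero_infinity=True)
print(inf_loss[0].item(), zeroed[0].item())
```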
I hope to fix the bug mentioned above this year (it’s somewhere in the issue list, but it might be closed)…

Best regards

Thomas


Agree with what Thomas said!

We’ve run into the issues Thomas mentioned with the CTC loss in PyTorch 1.0, so we’re still using warp-ctc (which now has 1.0 support). Once these issues have been addressed, I’ll migrate deepspeech.pytorch to the built-in loss function.


Hello Thomas,

I’m using Awni Hannun’s codebase for phoneme recognition, which uses his PyTorch 0.4 warp-ctc bindings. I recently started using PyTorch’s native CTC loss function in PyTorch 1.6, and I noticed that training with the native loss seems to take about 2-3 times longer to converge than with Awni’s warp-ctc loss.

I haven’t done a thorough convergence comparison of the two loss functions, but given your statement that “PyTorch 1.0 CTC loss is faster”, I wanted to ask: is the native PyTorch CTC loss faster than your warp-ctc bindings?

If the native PyTorch loss function is faster than your bindings, then perhaps something in my hyperparameter selection is making convergence with the native loss slower.

Thanks for your time @tom!