CTCLoss performance of PyTorch 1.0.0

Hi,

I’m working on an ASR project, and I recently changed my code to support PyTorch 1.0.0. It used @SeanNaren’s warp-ctc; however, when I replaced its CTCLoss with PyTorch’s brand new nn.CTCLoss, training no longer converged well. The loss became small, but the label error rate and word error rate did not improve, and all output sequences of the trained model contained only blanks. Has anybody seen something similar?

I’d like to test both CTCLoss implementations, but I don’t know whether there is a simple and reliable test example for CTC. Please let me know if you have a good one.


Hi,

thanks!
Can you try not summing inside the loss function but summing outside, and see if the losses agree?
If not, can we try to identify an example that is off?

Best regards

Thomas

Thanks @tom for replying! So do you mean to check whether sum(loss1) == loss2 when loss1 = nn.CTCLoss(reduction='none') and loss2 = nn.CTCLoss(reduction='sum')? I did use reduction='elementwise_mean', but I’ll try and let you know.
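Something along these lines would probably do as a minimal self-contained check (random inputs; the shapes are made up just for illustration):

import torch
import torch.nn as nn

T, N, C, S = 50, 2, 20, 10              # input length, batch size, classes (incl. blank), target length
log_probs = torch.randn(T, N, C).log_softmax(2)
targets = torch.randint(1, C, (N, S), dtype=torch.long)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), S, dtype=torch.long)

loss_none = nn.CTCLoss(reduction='none')(log_probs, targets, input_lengths, target_lengths)
loss_sum = nn.CTCLoss(reduction='sum')(log_probs, targets, input_lengths, target_lengths)
print(loss_none.sum().item(), loss_sum.item())   # should agree up to floating-point noise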

Yeah, so with reduction='none' you should get the negative log likelihood per batch element, and with reduction='sum' you’ll get the sum of them (which I believe is the default for @SeanNaren’s warp-ctc).
Of course, if you had an example where you say “this loss is wrong” or “that gradient is wrong”, that would be superb. (You could run your network and have the two calculated on .detach().requires_grad_()-ed outputs and compare loss and gradient until you find an example where you say “this is wrong”.)
There used to be a bug in the gradient calculation; we now have a gradcheck that attempts to trigger all code paths, so I would hope it is OK. But if you find an example, I’ll be more than happy to investigate and fix it for good.

Best regards

Thomas

Thanks @tom. I made a simple test comparing the CTC losses from @SeanNaren’s warp-ctc (with size_average=True and length_average=True) and from nn.CTCLoss, as follows:

frame_len: [473, 225]  target_len: [100, 38]
	warp_ctc: 4.356187343597412	CTCLoss(elem_mean): 23.244503021240234	CTCLoss(sum.item()): 3040.61865234375	CTCLoss(none): [2054.897705078125, 985.7210693359375]

frame_len: [793, 203]  target_len: [125, 44]
	warp_ctc: 4.358583927154541	CTCLoss(elem_mean): 23.780426025390625	CTCLoss(sum.item()): 4341.14892578125	CTCLoss(none): [3469.86328125, 871.2857666015625]

frame_len: [431, 121]  target_len: [104, 20]
	warp_ctc: 4.288755416870117	CTCLoss(elem_mean): 22.03105926513672	CTCLoss(sum.item()): 2367.3935546875	CTCLoss(none): [1839.9967041015625, 527.3968505859375]

frame_len: [391, 128]  target_len: [77, 17]
	warp_ctc: 4.3624701499938965	CTCLoss(elem_mean): 27.947921752929688	CTCLoss(sum.item()): 2264.12255859375	CTCLoss(none): [1686.162841796875, 577.959594726562]

As you can see, all values are computed with batch_size 2, and I’m fairly sure they are calculated as intended. I guess the elementwise_mean reduction made my training behave somewhat differently than before. I’m not sure what the exact effect is, but at least I now know how to make the losses from warp-ctc and nn.CTCLoss match. I’ll report back here if I find anything else strange. Thank you!

So in the first example, you can take the sum 3041 and divide by 473+225=698 to get the 4.356 that warp_ctc reports, or you can take 2055/100=20.55 and 986/38=25.94 and average those to the 23.24 you get with elem_mean.
The two weight short and long examples a bit differently. I mainly wrote mine so that the batch items’ losses are independent of each other (with Sean’s version, you get a different average loss if you re-shuffle the items into pairs and then average).
The reduction options are a bit of a funny match for CTC loss, but you should be able to do whichever scaling best suits your goals.
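
To spell out the arithmetic for your first example (plain Python, numbers copied from your table):

# Per-item losses from CTCLoss(reduction='none') in the first example above
loss_none = [2054.897705078125, 985.7210693359375]
frame_len = [473, 225]
target_len = [100, 38]

# warp-ctc with size_average=True, length_average=True: summed loss divided by the total input length
print(sum(loss_none) / sum(frame_len))                                      # ~4.356

# nn.CTCLoss(reduction='elementwise_mean'): each item divided by its target length, then averaged
print(sum(l / t for l, t in zip(loss_none, target_len)) / len(loss_none))   # ~23.24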

Best regards

Thomas

Yeah, what I understood earlier was that the denominator of the averaging is simply swapped from the frame length to the target length; now I see what elementwise_mean actually does, which was the key. Thank you again for your time, @t-vi!

This is a slightly different question, but I have a case to report here. To verify that warpctc_pytorch.CTCLoss and nn.CTCLoss produce the same loss, I ran my training with the following setup:

import torch
import torch.nn as nn
import warpctc_pytorch as wp
...
# reduction='sum' divided below by the total frame length should match
# warp-ctc with size_average=True and length_average=True
self.loss0 = nn.CTCLoss(blank=0, reduction='sum')
self.loss1 = wp.CTCLoss(blank=0, size_average=True, length_average=True)
...
xs, ys, frame_lens, label_lens = data
ys_hat, frame_lens = self.model(xs, frame_lens)
d = frame_lens.sum().float().cuda()
loss0 = self.loss0(ys_hat, ys, frame_lens, label_lens).div_(d)
loss1 = self.loss1(ys_hat, ys, frame_lens, label_lens)
# log whenever the two losses disagree beyond a small tolerance
if torch.abs(loss0 - loss1).item() > 1e-3:
    logger.info(f"nn.CTCLoss: {loss0.item()}")
    logger.info(f"wp.CTCLoss: {loss1.item()}")
loss = loss0
# guard against inf/nan losses so they don't break training
if torch.isnan(loss) or loss.item() == float("inf") or loss.item() == -float("inf"):
    logger.warning("received an inf loss, setting loss value to 0")
    loss.data = torch.tensor(0.).cuda() if self.use_cuda else torch.tensor(0.)
...

However, after a while the training stopped and showed:

2018-10-18 20:38:21,276 [INFO ] nn.CTCLoss: inf
2018-10-18 20:38:21,652 [INFO ] wp.CTCLoss: 0.6684228777885437
2018-10-18 20:38:21,987 [WARNING] received an inf loss, setting loss value to 0

So I wonder whether there is some numerical instability inside nn.CTCLoss, compared to the others, that makes them differ as shown above. I computed nn.CTCLoss in GPU mode, not using CuDNN.

Hi,

one of them is certainly wrong and I would like to investigate it in detail. Would it be possible to dump a set of inputs when you get that and send a zip via mail or so?

Best regards

Thomas

I sent you an email with the dumped inputs for which nn.CTCLoss produces inf while warp-ctc gives a finite value. Please let me know if you figure out something. Thank you!

Hi,

thank you.
So if you look at the loss with reduction='none', you see that element number 4 of the batch has infinite loss. This is because the input length is smaller than the target length, i.e. you cannot possibly get an alignment between input and target. (The actual condition is a bit stricter when the target has repeated labels, because the network then needs to emit a blank in between and therefore needs a longer input; you get the necessary condition input length >= target length + repetitions for the loss to be finite - and when you have a softmax it is also sufficient, barring numerical overflow to infinity.)
I’m not entirely sure what warp-ctc does, but from my recollection it may just report 0 instead. Note that the gradient for the inputs in question will be NaN; maybe it would be good to optionally clip that to zero (which you could do with a backward hook on the inputs now).
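A minimal sketch of that necessary condition (the helper is just for illustration, not part of either library):

# For the CTC loss to be finite, the input must be at least as long as the target
# plus the number of repeated consecutive labels (a blank must separate repeats).
def ctc_feasible(input_length, target):
    repetitions = sum(1 for a, b in zip(target, target[1:]) if a == b)
    return input_length >= len(target) + repetitions

print(ctc_feasible(5, [1, 2, 2, 3]))   # True: 5 >= 4 + 1
print(ctc_feasible(4, [1, 2, 2, 3]))   # False: 4 < 4 + 1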

Best regards

Thomas

(Edited “<” vs. “>” based on @jinserk’s correction below. Thanks!)

Ahhh… I didn’t realize there were data pairs whose input length is shorter than the target length (is that what you meant to say? you wrote input length > target length instead). I’ll filter those out and try another training run. Thank you so much!

Jin

Hi @tom,

Can I ask whether the nn.CTCLoss grad_outputs are connected to the gradient calculation? According to the discussion here, @SeanNaren also experienced slower convergence. I’m still comparing the training runs of my code with nn.CTCLoss and warpctc.CTCLoss, but I’m pretty sure that training with nn.CTCLoss converges more slowly than with the other. Of course I filtered out the degenerate samples whose input lengths are < 2 * target lengths. If this is the origin of the difference and the implementation of nn.CTCLoss is correct, then I need to know how I can make the convergence as fast as warpctc even when using the official nn.CTCLoss.
Thank you!

Yes, they are, but if you put in ones, that should be OK, no?
You could try to see if you find something for the gradient using a similar method to the one you used for the forward, but I would be surprised if we put out a totally different gradient. (I think we do pass gradcheck for all code paths by now, but there might be other bugs, or the gradient might not be as precise in fp32 as we might wish - in order to speed up the computation, it moves out of log space a bit earlier than one might for maximum precision; it’s a trick I took from a talk by one of the people who implemented the CuDNN CTC loss at a GTC a while back.)

I don’t think backpropagating the output gradient instead of just ones should make a difference (well, unless you backpropagate output gradients that aren’t all ones), but you’d have to ask Sean about his experience.

Best regards

Thomas

What does “put in ones” mean? Is it loss.backward(loss.grads.new_ones(loss.grads.size())) or something else?
I have never seen such usage, so I really don’t know about it. Could you let me know how to do that?

I was under the impression that Sean’s warp-ctc always assumes that you call loss.backward() directly on the CTC loss, i.e. the grad_output of the loss is 1, which is the same as not reducing and using loss.backward(torch.ones_like(loss)). Now if you do funny things like loss2 = exp(-5*loss) + something_else, you will have different values backpropagating.
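
If you want to convince yourself, a small self-contained check (random inputs, made-up shapes) could look like:

import torch
import torch.nn as nn

T, N, C, S = 30, 2, 10, 5
log_probs = torch.randn(T, N, C).log_softmax(2).requires_grad_()
targets = torch.randint(1, C, (N, S), dtype=torch.long)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), S, dtype=torch.long)
ctc = nn.CTCLoss(reduction='sum')

loss = ctc(log_probs, targets, input_lengths, target_lengths)
loss.backward()                              # implicit grad_output of 1 for a scalar loss
g1 = log_probs.grad.clone()

log_probs.grad = None
loss = ctc(log_probs, targets, input_lengths, target_lengths)
loss.backward(torch.ones_like(loss))         # explicitly passing ones
print(torch.allclose(g1, log_probs.grad))    # True: the two gradients match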

Best regards

Thomas

When I set loss = 1, it complained with *** AttributeError: 'int' object has no attribute 'backward'. Is there a way to replace a Variable’s data with 1 without touching the grads? Sorry for bothering you with so many questions, and thank you!

Sorry, the “immediate backward” probably is a bit off track.
If you want to validate gradients, you can do something like:

log_probs.retain_grad()                      # keep .grad on the non-leaf network output
lp2 = log_probs.detach().requires_grad_()    # an independent copy of the activations for the second loss
loss1 = loss_fn1(log_probs, ...)
loss2 = loss_fn2(lp2, ...)
loss1.backward()
loss2.backward()

and then compare lp2.grad to log_probs.grad.
If you find anything suspicious, please let me know.

Best regards

Thomas

Hi @tom,
Could I ask one more question? According to warpctc, the probs should be on a linear scale, without LogSoftmax, and the CTCLoss converts the probs to log scale internally. nn.CTCLoss, on the other hand, requires LogSoftmax output, not just the log of the output. The difference here is whether a softmax is applied, so I’d like to know whether the LogSoftmax output is the correct thing to give to the loss function.
Thanks!

Yes, for nn.CTCLoss you should pass log probabilities (i.e. have log softmax applied). Log-softmax is idempotent (log_softmax(log_softmax(a, 1), 1) == log_softmax(a, 1)), which is why you can feed the log probabilities to warp-ctc as well.
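A quick way to check that (arbitrary random activations):

import torch
import torch.nn.functional as F

x = torch.randn(50, 2, 20)                            # made-up (T, N, C) network outputs
lp = F.log_softmax(x, dim=2)
print(torch.allclose(F.log_softmax(lp, dim=2), lp))   # True: applying log_softmax twice changes nothing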

Best regards

Thomas