CTCLoss performance of PyTorch 1.0.0

I'm also stuck with CTCLoss and don't know how to fix it. Please help me point out what's wrong.

Problem: ASR

Questions:

  1. These results are not correct; how do I update the parameters correctly?
  2. How do I decode the CTC output to calculate accuracy?

Data input:
X: AllMFCCs: (48840, 247, 20): batches, mfcc_len, features (batch size will be 16, 32, or more)
Y: char_vec: (48840, 30), values range: 0 … 6
Ylen: char_length: (48840,), values range: 1 … 30

Define the net to train:

'''ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
 
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
    def forward(self, x):        
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=256*8):
        super(ResNet, self).__init__()
        Nsize = 32
        self.in_planes = Nsize
        self.conv1 = nn.Conv2d(1, Nsize, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(Nsize)
        self.layer1 = self._make_layer(block, Nsize, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 64, num_blocks[1], stride=1)
        self.layer3 = self._make_layer(block, 128, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 256, num_blocks[3], stride=2)
        self.linear = nn.Linear(3840, num_classes)   # 3840 = 256 channels * 15 * 1 after the avg pool
        self.Smax   = nn.Softmax(dim=-1)   # note: nn.CTCLoss expects log-probabilities (see the replies below)
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
 
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        out = out.view(out.size(0), -1, 8)   # (batch, frames, classes), e.g. torch.Size([batch, 256, 8])
        out = self.Smax(out)
        return out

def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])

def test():
    net = ResNet18()
    bat = torch.randn(32, 1, 247, 20)   # dummy batch shaped like the MFCC input
    y = net(bat)
    print(y.size())   # expected: torch.Size([32, 256, 8])

test()
import time
print(time.asctime())

Code in the training main loop:

import time
import torch
import torch.backends.cudnn as cudnn
from torch import nn, optim
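# assumption: ConvertNpArray3D_2Tensor4D is not defined in the post; judging from
# the shape comments below ((32, 247, 20) -> torch.Size([32, 1, 247, 20])),
# a plausible implementation is
def ConvertNpArray3D_2Tensor4D(x):
    return torch.from_numpy(x).float().unsqueeze(1)   # insert a channel dimension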

ctc_loss = nn.CTCLoss(reduction='mean')
net = ResNet18()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

optimizer = optim.SGD(net.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
print(time.asctime())

net.train()
train_loss = 0
correct = 0
total = 0

k = 0
BatchSize = 16
for batch_idx in range(0, len(AllMFCCs), BatchSize):
    Batch_Input    = AllMFCCs[batch_idx:batch_idx + BatchSize]
    target_lengths = char_len[batch_idx:batch_idx + BatchSize]
    targets        = char_vec[batch_idx:batch_idx + BatchSize]
    targets        = targets + 1   # shift labels up by one so index 0 is free for the CTC blank
    targets        = torch.tensor(targets, dtype=torch.long)
    target_lengths = torch.tensor(target_lengths, dtype=torch.long)
    
    optimizer.zero_grad()
    Batch_Input1 = ConvertNpArray3D_2Tensor4D(Batch_Input)
    # bat inp: (32, 247, 20)
    # bat out: torch.Size([32, 1, 247, 20])

    Batch_Input1, targets = Batch_Input1.to(device), targets.to(device)

    log_probs = net(Batch_Input1)
    # note: the original had log_probs = log_probs.detach().requires_grad_() here,
    # which cuts the autograd graph so no gradient ever reaches the network weights

    log_probs = log_probs.transpose(1, 0)   # (T, N, C), as nn.CTCLoss expects
    input_lengths = torch.full((log_probs.shape[1],), log_probs.shape[0], dtype=torch.long)
    
    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    print(k, 'loss:', loss.item(), train_loss)
    k += 1
    if k == 100:
        break
    
#     _, predicted = log_probs.max(1)
#     total += targets.size(0)
#     correct += predicted.eq(targets).sum().item()


print('Done!')
# train()

Results:

0 loss: -8.53072738647461 -8.53072738647461
1 loss: -8.66711711883545 -17.19784450531006
2 loss: -8.59152889251709 -25.78937339782715
3 loss: -8.418535232543945 -34.207908630371094
...........................
97 loss: -8.594472885131836 -840.0682668685913
98 loss: -8.93405532836914 -849.0023221969604
99 loss: -8.47213363647461 -857.4744558334351

Hi @ntanh,

You have to use LogSoftmax instead of Softmax at the output of the net. I'm not sure what metric you want for accuracy, but CTCLoss is just a loss function, so you might need to calculate a label error rate (LER) using edit distance, or something similar, in eval mode.

Jin
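
For reference, PyTorch has no built-in equivalent of Keras' K.ctc_decode, so decoding has to be done by hand. A minimal sketch of greedy (best-path) decoding plus a Levenshtein-based LER, assuming the blank has index 0 (the nn.CTCLoss default):

import torch

def greedy_decode(log_probs, blank=0):
    # log_probs: (T, N, C), the same tensor fed to nn.CTCLoss
    best = log_probs.argmax(dim=-1).t()   # (N, T): best class per frame
    decoded = []
    for seq in best:
        prev, out = blank, []
        for idx in seq.tolist():
            if idx != blank and idx != prev:   # collapse repeats, drop blanks
                out.append(idx)
            prev = idx
        decoded.append(out)
    return decoded

def edit_distance(a, b):
    # classic single-row Levenshtein dynamic program
    dp = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, cb in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (ca != cb))
    return dp[-1]

# LER for one sample: edit_distance(prediction, reference) / len(reference)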


Thanks @jinserk,

  • To calculate accuracy, you're right, it should be WER or LER. But what I mean is: Keras has K.ctc_decode, for example; which function in PyTorch can decode the output of a net trained with nn.CTCLoss?
  • I just replaced Softmax with LogSoftmax as you suggested; these are the results:
Thu Oct 25 09:16:51 2018
Net out shape: [batch,256,8]
 self.Smax   = nn.LogSoftmax(dim=2)
0 loss: 33.141990661621094 33.141990661621094
1 loss: 34.519630432128906 67.66162109375
2 loss: 33.144351959228516 100.80597305297852
3 loss: 32.033241271972656 132.83921432495117
4 loss: 34.247257232666016 167.0864715576172
5 loss: 30.669713973999023 197.7561855316162
................
197 loss: 39.64136505126953 6625.480464935303
198 loss: 37.10675811767578 6662.5872230529785
199 loss: 38.627960205078125 6701.215183258057

So I think it is still not correct. I really wonder why my network can't update its weights (the loss on the last minibatch is about the same as on the first).

These are the results when I change dim=2 to dim=1 in LogSoftmax:

Thu Oct 25 09:41:22 2018
Net out shape: [batch,256,8]
 self.Smax   = nn.LogSoftmax(dim=1)
Batch:0 - Loss:   99.08631896972656 - Total Loss:   99.08631896972656
Batch:1 - Loss:    95.8052978515625 - Total Loss:  194.89161682128906
Batch:2 - Loss:   95.28765869140625 - Total Loss:   290.1792755126953
Batch:3 - Loss:  100.45629119873047 - Total Loss:   390.6355667114258
Batch:4 - Loss:  103.56320190429688 - Total Loss:  494.19876861572266
Batch:5 - Loss:   95.95877075195312 - Total Loss:   590.1575393676758
Batch:6 - Loss:    99.9146728515625 - Total Loss:   690.0722122192383
Batch:7 - Loss:   95.26667785644531 - Total Loss:   785.3388900756836
Batch:8 - Loss:  106.40909576416016 - Total Loss:   891.7479858398438

Still not correct :frowning:
Please @jinserk, @tom, help me :smile:

@jinserk You are right, we need to add log_softmax to get the right loss. Before doing this, I got an increasing loss that started from a negative value. :joy: Since the new release of PyTorch 1.0.0, a lot of people don't know how to use the official ctcloss properly. @tom, would you write an official tutorial about how to use it and the difference from @SeanNaren's warp-ctc? Thanks.


I get NaNs from nn.CTCLoss even when input_length > target_length. I can't be sure what warp_ctc gives in this case, as I'm having trouble installing warp_ctc. Are there any other cases I should take care of? I'm doing something like this now:

criterion = warpctc_pytorch.CTCLoss()
out = model(inputs)
loss = criterion(out, targets, sizes, target_sizes)
loss = loss / inputs.size(0)
optimizer.zero_grad()
loss.backward()
optimizer.step()

converting this to

criterion = torch.nn.CTCLoss(reduction="none")
out = model(inputs)
loss = criterion(out, targets, sizes, target_sizes)
loss = loss.sum() / inputs.size(0)
optimizer.zero_grad()
loss.backward()
optimizer.step()

@swethmandava as mentioned by others above, torch.nn.CTCLoss takes the output of a LogSoftmax, unlike warpctc_pytorch, which takes the raw, un-normalized logits (it applies the softmax internally).

Maybe you need an out = F.log_softmax(model(inputs), dim=-1)? (in case your model doesn't have a log_softmax as its last operation)
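
Side by side (a sketch reusing the names from the snippets above; warp_criterion and torch_criterion stand for the two criteria defined earlier):

import torch.nn.functional as F

out = model(inputs)   # raw logits, shape (T, N, C)

# warpctc_pytorch.CTCLoss: takes the raw logits directly
loss = warp_criterion(out, targets, sizes, target_sizes)

# torch.nn.CTCLoss: takes log-probabilities
loss = torch_criterion(F.log_softmax(out, dim=-1), targets, sizes, target_sizes)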


I have added a log_softmax to my network output; it gets NaN after some epochs, but it's OK with @SeanNaren's warp-ctc.

I’ll write a tutorial + drop in replacement, but not today. :wink:

@jun_zhou That is likely due to feeding “impossible” inputs to CTCLoss. WarpCTC zeros them. I’ll eventually provide a notebook with a wrapper that does that.

@swethmandava input_length > target_length is only a necessary condition; you actually need input_length >= target_length + the number of repetitions in the target. That is because your network needs to output x<blank>x to match a target of xx.
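
A quick sanity check for that condition (the helper name is hypothetical):

import torch

def min_input_length(target):
    # one frame per label, plus one extra frame for the blank that must separate
    # each pair of adjacent repeated labels (target xx needs output x<blank>x)
    repeats = (target[1:] == target[:-1]).sum().item()
    return target.numel() + repeats

print(min_input_length(torch.tensor([3, 3, 5])))   # 4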

My gut feeling is that people will likely want an option to zero infinite loss eventually. My apologies for not getting this into 1.0.

Best regards

Thomas

@swethmandava, you could register a backward hook like:

def backward_hook(module, grad_input, grad_output):
    for g in grad_input:
        if g is not None:
            g[g != g] = 0   # g != g is true exactly where g is NaN; zero those entries

model.register_backward_hook(backward_hook)

Then it will behave similarly to warpctc_pytorch. Please note that this could distort the gradient direction, as @tom mentioned.


Thanks, masking the NaNs works. I did something like this:

loss_batch = torch.where(loss_batch != loss_batch, torch.zeros_like(loss_batch), loss_batch)

Do we want to mask the infs as well? I don't see a problem right now without it.

But do you get NaNs in the forward pass when you mask them in the backward?
That definitely should not happen, and I'd be very keen to see the inputs that cause it.

One might try whether loss == inf (before reduction) works as a criterion for zeroing the loss, and whether you get a performance benefit from having only a single comparison per sample.

Best regards

Thomas
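
That could look roughly like this, reusing the reduction="none" snippet from above:

import torch

loss = criterion(out, targets, sizes, target_sizes)                  # per-sample losses, shape (N,)
loss = torch.where(torch.isinf(loss), torch.zeros_like(loss), loss)  # zero the "impossible" samples
loss = loss.sum() / inputs.size(0)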

Hi @jinserk, @tom, I encountered this case: some boxes' widths after resizing are smaller than the target length. I use ~torch.isinf(loss) to filter out the inf items, divide by the target length, and then compute the mean. After calling backward(), the weights become NaN. How can I handle it?

I have added a new comment in this GitHub thread, which I think points to the cause of the NaN.

Good insights in this thread, thanks guys! To try to make things a bit easier, I've made a script that uses the built-in CTC loss function and replicates the warp-ctc tests. It seems to give the same results when you run pytest -s test_gpu.py and pytest -s test_pytorch.py, but it does not test the above issue where we have two different sequence lengths in the batch. The test is here: https://github.com/SeanNaren/warp-ctc/blob/feature/pytorch_migration/pytorch_binding/tests/test_pytorch.py

I’ve taken these changes and used the built-in CTC loss with deepspeech.pytorch to train an AN4 model. It did not converge at all, and the loss went a bit wild, so there are definitely a few things to investigate. I'll try to do some more digging!

EDIT: I should mention that, thanks to Jinserk, warp-ctc is 1.0-compatible on the pytorch_bindings branch!

I have a similar issue. My acoustic model does not converge when using nn.CTCLoss. The model uses 4 BiLSTM layers and one linear layer. When I looked at the gradient values of the weights in both the LSTM and linear layers, they were in the range of 1e-6 to 1e-11, while the gradient values in a TensorFlow implementation were about 1e-1 to 1e2. I wonder if this is caused by the new loss function.

PS: I did use log_softmax on the input to the loss function.
The TensorFlow implementation worked normally.
I did try setting ‘reduction’ to different values.
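
(For anyone who wants to reproduce this kind of gradient check, a quick sketch over a generic model:)

for name, p in model.named_parameters():
    if p.grad is not None:
        # largest absolute gradient entry in each parameter tensor, inspected after loss.backward()
        print(name, p.grad.abs().max().item())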

Regarding text-to-speech, I can report that PyTorch's built-in CTC loss always, always(!) runs into the same bad local minimum (a combination of non-CTC blanks " " and "E"s), no matter how I change the hyperparameters, and does not improve from there. Under the same conditions, training with the warp-ctc loss converges smoothly, as I have also seen when using TensorFlow's CTC loss.

Under the assumption that there is no bug on my side, what conclusions could one draw from that about the loss and gradients?

If you use varying input lengths: we fixed a bug in it this week that will be in 1.0.1; can you try master, please?

Best regards

Thomas

I finally sent a PR to zero out infinite losses, so with any luck the backward hook should not be needed if you set zero_infinity=True. Your comments on the PR are greatly appreciated.

Best regards

Thomas
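
(For reference, in later PyTorch releases this landed as the zero_infinity argument of nn.CTCLoss:)

import torch

ctc = torch.nn.CTCLoss(zero_infinity=True)   # infinite losses and their gradients are zeroed out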


@tom
Sorry for the late response; it works! Thanks!

Another question, regarding parallelization of the CTC loss:
I'm wondering how to parallelize the loss computation, given that log_probs is expected in the (T, N, C) data format while the other inputs are batch-first. Since DataParallel expects a common dimension to scatter along, this does not seem possible to me at the moment.

Could you transpose the log_probs just before feeding them into the CTC loss, and run batch-first before that?

Best regards

Thomas
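
A minimal sketch of that suggestion (the wrapper name is hypothetical): keep every input batch-first so DataParallel can scatter along dim 0, and transpose only right before the loss:

import torch.nn as nn

class BatchFirstCTC(nn.Module):
    def __init__(self):
        super().__init__()
        self.ctc = nn.CTCLoss()

    def forward(self, log_probs, targets, input_lengths, target_lengths):
        # log_probs arrives batch-first as (N, T, C); nn.CTCLoss wants (T, N, C)
        return self.ctc(log_probs.transpose(0, 1), targets, input_lengths, target_lengths)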