CTCLoss predicts blanks after a few batches

Hello,

I’m trying to port over a CTC network from Keras. I’ve based the model off of https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py (basically replacing the warp_ctc.CTCLoss with the pytorch CTCLoss because warp_ctc won’t compile). This is using pytorch version 1.0.1, CUDA version 9.0.

Training code is:

net = CRNN(32, 3, len(labels), nh=256)
net.to(device)

ctc_loss = nn.CTCLoss(blank=blank_ind)
ctc_loss.to(device)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.09)

for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, data_labels = data
        for k,v in inputs.items():
            v = v[0]
            inputs[k] = v.to(device)
        for k,v in data_labels.items():
            v = v[0]
            data_labels[k] = v.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        predicteds = net(inputs)
        log_probs = predicteds.log_softmax(2)

        input_lengths = torch.full((inputs['image'].shape[0],),
                                   log_probs.shape[0], dtype=torch.int)

        loss = ctc_loss(log_probs, data_labels['the_labels'], input_lengths, data_labels['label_length'])

        loss.backward()
        optimizer.step()

Dimensions of variables (100 is the batch size):

  • Input images = (3, 32, 248)
  • log_probs = (63, 100, 250)
  • data_labels['the_labels'] = (100, 12)
  • input_lengths = (100,)
  • data_labels['label_length'] = (100,)
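The T = 63 time steps in log_probs come from the CNN stack of the CRNN model, which collapses the 32-pixel height to 1 and shrinks the 248-pixel width. Below is a small sketch that traces an image size through those layers. The helper names conv_out and crnn_feature_size are hypothetical, and the layer list is transcribed by hand from the model definition, so it should be rechecked if the architecture changes.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a Conv2d/MaxPool2d layer along one axis (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def crnn_feature_size(height, width):
    """Trace (height, width) through the CNN layers of the CRNN model.

    Each tuple is (kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w),
    transcribed from the ks/ss/ps lists and the MaxPool2d arguments.
    """
    layers = [
        (3, 3, 1, 1, 1, 1),  # conv0
        (2, 2, 2, 2, 0, 0),  # pooling0: MaxPool2d(2, 2)
        (3, 3, 1, 1, 1, 1),  # conv1
        (2, 2, 2, 2, 0, 0),  # pooling1: MaxPool2d(2, 2)
        (3, 3, 1, 1, 1, 1),  # conv2
        (3, 3, 1, 1, 1, 1),  # conv3
        (2, 2, 2, 1, 0, 1),  # pooling2: MaxPool2d((2, 2), (2, 1), (0, 1))
        (3, 3, 1, 1, 1, 1),  # conv4
        (3, 3, 1, 1, 1, 1),  # conv5
        (2, 2, 2, 1, 0, 1),  # pooling3: MaxPool2d((2, 2), (2, 1), (0, 1))
        (2, 2, 1, 1, 0, 0),  # conv6: kernel 2, no padding
    ]
    for kh, kw, sh, sw, ph, pw in layers:
        height = conv_out(height, kh, sh, ph)
        width = conv_out(width, kw, sw, pw)
    return height, width


print(crnn_feature_size(32, 248))  # (1, 63)
```

The final width becomes the sequence length T, which is why every entry of input_lengths is 63 here.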

Sample format of the labels (where 249 is the blank index passed to CTCLoss):

labels tensor([[  2,  98, 101,  83,   0, 249, 249, 249, 249, 249, 249, 249],
        [  2, 124,  13,  41,   0, 249, 249, 249, 249, 249, 249, 249],
        [  2,  24, 113,  13, 109,   0, 249, 249, 249, 249, 249, 249],
        [  2, 112, 114, 124,   2, 249, 249, 249, 249, 249, 249, 249],
        [  2,  28,  30,  76,   0, 249, 249, 249, 249, 249, 249, 249],
        [  0,  41,  14,  98,   2, 249, 249, 249, 249, 249, 249, 249],
        [  2,  41, 125,  13,  41,   0, 249, 249, 249, 249, 249, 249],
        [  0,  76,  13, 124,   2, 249, 249, 249, 249, 249, 249, 249],
        [  2,  24, 125,  13,  83,   0, 249, 249, 249, 249, 249, 249],
        [  0,  41,  43,  35,   2, 249, 249, 249, 249, 249, 249, 249],
        [  2, 112, 114,  35,   2, 249, 249, 249, 249, 249, 249, 249]],
       device='cuda:0')
labels_length tensor([5, 5, 6, 5, 5, 5, 6, 5, 6, 5, 5, 6, 5, 6, 6, 5, 5, 5, 5, 5, 6, 6, 6, 5,
        5, 5, 5, 5, 5, 5, 5, 6, 6, 5, 6, 6, 5, 5, 5, 5, 6, 6, 5, 5, 5, 5, 3, 5,
        6], device='cuda:0')

What I’ve tried:

  • Setting the blank index to different values (either the length of the labels or 0)
  • Using a different optimizer/smaller learning rates (suggested in the thread “CTCLoss predicts all blank characters”, though that one uses warp_ctc)
  • Training on just input images that have a sequence (rather than images with nothing in them)

In all cases the network will produce random labels for the first couple of batches before only predicting blank labels for all subsequent batches.

Is there anything that I’m missing here?

How did you solve the problem? I am facing a similar error when using the PyTorch CTC loss in PyTorch 1.0.1.post2.

I haven’t managed to solve it yet. I tried upgrading to CUDA 10.0, as well as some other tweaks to the optimizer, but it still predicts blanks after some time.

If you have a working example, I’d take a look. It doesn’t sound like any issue I’m aware of. (The ones I know of are: 1) NaN if you have impossible targets and don’t pass zero_infinity=True (available in master / the nightly builds listed under “preview”). 2) A bad loss/gradient for all-blank targets.)

Best regards

Thomas

Sure, thanks so much for taking a look! Here’s one of the smallest datafiles (https://www.dropbox.com/s/d3e147uwnsavx3v/8184.hdf5?dl=0). I think you should just need to update the relevant directories to get it to run.

Other potentially relevant aspects of the data:

  • Batches are hacked a bit to make the file IO faster
  • There are many unused labels, because I’d like to avoid training from scratch when new labels are needed

Let me know if there’s anything else you need!

The full script is as follows:

import os
import torch
import torchvision
import torchvision.transforms as transforms
import h5py
import numpy as np
import torch.utils.data as data
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import random
import itertools

working_dir = r'E:\Data\Overwatch\models\kill_feed_ctc'
os.makedirs(working_dir, exist_ok=True)
log_dir = os.path.join(working_dir, 'log')
TEST = True
train_dir = r'E:\Data\Overwatch\training_data\kill_feed_ctc'

cuda = True
seed = 1
batch_size = 100
test_batch_size = 100
epochs = 10
lr = 0.01
momentum = 0.5
log_interval = 10


def labels_to_text(ls):
    ret = []
    for c in ls:
        if c >= len(labels):
            continue
        ret.append(labels[c])
    return ret


def decode_batch(out):
    ret = []
    print(out.shape)
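    # CRNN.forward returns time-major output (T, batch, classes), so loop over
    # the batch axis (dim 1); the first 2 time steps are skipped, as in Keras
    for j in range(out.shape[1]):
        out_best = list(np.argmax(out[2:, j], 1))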
        print(out_best)
        out_best = [k for k, g in itertools.groupby(out_best)]
        print(out_best)
        outstr = labels_to_text(out_best)
        print(outstr)
        ret.append(outstr)
    return ret


def load_set(path):
    ts = []
    with open(path, 'r', encoding='utf8') as f:
        for line in f:
            ts.append(line.strip())
    return ts


labels = load_set(os.path.join(train_dir, 'labels_set.txt'))

class_count = len(labels)
blank_ind = len(labels) - 1

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


class TrainDataset(data.Dataset):
    def __init__(self):
        super(TrainDataset, self).__init__()
        self.data_num = 0
        self.data_indices = {}
        count = 0
        for f in os.listdir(train_dir):
            if f.endswith('.hdf5'):
                with h5py.File(os.path.join(train_dir, f), 'r') as h5f:
                    self.data_num += h5f['train_img'].shape[0]
                    self.data_indices[self.data_num] = os.path.join(train_dir, f)
                count += 1
                if count > 1:
                    break
        self.weights = {}
        print('DONE SETTING UP')

    def __len__(self):
        return int(self.data_num / batch_size)

    def __getitem__(self, index):
        start_ind = 0
        real_index = index * batch_size
        for i, (next_ind, v) in enumerate(self.data_indices.items()):
            path = v
            if real_index < next_ind:
                break
            start_ind = next_ind

        real_index = real_index - start_ind
        inputs = {}
        outputs = {}
        with h5py.File(path, 'r') as hf5:

            inputs['image']= torch.from_numpy(np.transpose(hf5['train_img'][real_index:real_index+batch_size, ...],  (0, 3 ,2, 1))).float()
            labs = hf5["train_label_sequence"][real_index:real_index+batch_size, ...].astype(np.int32)
            #labs += 1
            labs[labs > blank_ind] = blank_ind
            outputs['the_labels'] = torch.from_numpy(labs).long()
            print(hf5["train_label_sequence_length"][real_index:real_index+batch_size].shape)
            outputs['label_length'] = torch.from_numpy(hf5["train_label_sequence_length"][real_index:real_index+batch_size]).long()
            print(outputs['label_length'].shape)

            # For removing all blank images
            #inds = outputs['label_length'] != 1
            #inputs['image'] = inputs['image'][inds]
            #outputs['the_labels'] = outputs['the_labels'][inds]
            #outputs['label_length'] = outputs['label_length'][inds]

        return inputs, outputs


class TestDataset(data.Dataset):
    def __init__(self):
        super(TestDataset, self).__init__()
        self.data_num = 0
        self.data_indices = {}
        count=0
        for f in os.listdir(train_dir):
            if f.endswith('.hdf5'):
                with h5py.File(os.path.join(train_dir, f), 'r') as h5f:
                    self.data_num += h5f['val_img'].shape[0]
                    self.data_indices[self.data_num] = os.path.join(train_dir, f)
                count += 1
                if count > 1:
                    break

    def __getitem__(self, index):
        start_ind = 0
        real_index = index * test_batch_size
        for i, (next_ind, v) in enumerate(self.data_indices.items()):
            path = v
            if real_index < next_ind:
                break
            start_ind = next_ind
        real_index = real_index - start_ind
        inputs = {}
        outputs = {}
        with h5py.File(path, 'r') as hf5:
            inputs['image']= torch.from_numpy(np.transpose(hf5['val_img'][real_index:real_index+test_batch_size, ...], (0, 3 ,2, 1))).float()

            outputs['the_labels'] = torch.from_numpy(hf5["val_label_sequence"][real_index:real_index+test_batch_size, ...].astype(np.int32)).long()
            outputs['label_length'] = torch.from_numpy(np.reshape(hf5["val_label_sequence_length"][real_index:real_index+test_batch_size], (-1, 1))).long()
        return inputs, outputs

    def __len__(self):
        return int(self.data_num / test_batch_size)


train_set = TrainDataset()
trainloader = torch.utils.data.DataLoader(train_set, batch_size=1,
                                          shuffle=True)
test_set = TestDataset()
testloader = torch.utils.data.DataLoader(test_set, batch_size=1,
                                          shuffle=True)

def imshow(img):
    img = img.cpu()
    npimg = img.numpy()
    img = np.transpose(npimg, (1,2, 0))[:,:, [2,1,0]]/255
    plt.imshow(img)
    plt.show()

# get some random training images
dataiter = iter(trainloader)
inputs, outputs = next(dataiter)
print(inputs['image'].shape)
print(outputs['the_labels'].shape)
# show images
imshow(torchvision.utils.make_grid(inputs['image'][0, :4, ...],nrow=1))

for i in range(4):
    print(outputs['the_labels'][0,i])
    print('Labels', ' '.join(labels[x] for x in outputs['the_labels'][0,i] if x < len(labels)))


class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output


class CRNN(nn.Module):

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        input = input['image']
        # conv features
        conv = self.cnn(input)
        print(conv.shape)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        print(conv.shape)

        # rnn features
        output = self.rnn(conv)
        print(output.shape)
        return output


net = CRNN(32, 3, len(labels), nh=256)
net.to(device)

ctc_loss = nn.CTCLoss(blank=blank_ind)
ctc_loss.to(device)
#optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
optimizer = optim.Adagrad(net.parameters())

import time
for epoch in range(2):  # loop over the dataset multiple times
    batch_begin = time.time()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        print(i)
        # get the inputs
        #begin = time.time()
        inputs, data_labels = data
        for k,v in inputs.items():
            v = v[0]
            inputs[k] = v.to(device)
        for k,v in data_labels.items():
            v = v[0]
            data_labels[k] = v.to(device)
        #print('Loading data took: {}'.format(time.time()-begin))
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        #begin = time.time()
        predicteds = net(inputs)
        print(predicteds.shape)
        predicteds = predicteds.log_softmax(2)
        print(predicteds.shape)
        input_lengths = torch.full((inputs['image'].shape[0],),
                                   predicteds.shape[0]).long()
        input_lengths = input_lengths.to(device)  # .to() is not in-place

        #print('Predicting data took: {}'.format(time.time()-begin))
        decode_batch(predicteds.cpu().detach().numpy())
        #begin = time.time()
        print(predicteds.shape)
        #print('TRUE LABELS', data_labels['the_labels'].shape)
        #print('TRUE LABELS', data_labels['the_labels'])
        #print(input_lengths)
        #print(input_lengths.shape)
        #print(data_labels['label_length'])
        #print(data_labels['label_length'].shape)
        loss = ctc_loss(predicteds, data_labels['the_labels'], input_lengths, data_labels['label_length'])
        print(loss.item())

        #print('Loss calculation took: {}'.format(time.time()-begin))
        #begin = time.time()
        loss.backward()
        optimizer.step()
        #print('Back prop took: {}'.format(time.time()-begin))

        # print statistics
        #begin = time.time()
        running_loss += loss.item()
        #print(loss.item())
        if i % 50 == 49:    # print every 50 mini-batches
            #print(predicteds['hero'])
            #print(labels['hero'])
            print('Epoch %d, %d/%d, loss: %.3f' %
                  (epoch + 1, i + 1, len(train_set), running_loss / 50))
            running_loss = 0.0
            print('Batch took: {}'.format(time.time()-batch_begin))
        batch_begin = time.time()



print('Finished Training')

dataiter = iter(testloader)
inputs, outputs = next(dataiter)
for k,v in inputs.items():
    v = v[0]
    inputs[k] = v.to(device)
for k,v in outputs.items():
    v = v[0]
    outputs[k] = v.to(device)

# print images
imshow(torchvision.utils.make_grid(inputs['image'][:4, ...], nrow=1))
with torch.no_grad():
    predicteds = net(inputs)
    predicted_labels = decode_batch(predicteds.cpu().numpy())
    for i in range(4):
        print(predicted_labels[i])
        print('Labels', ' '.join(labels[x] for x in outputs['the_labels'][i] if x < len(labels)))
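The script's decode_batch was adapted from Keras, where the network output is batch-major; CRNN.forward returns time-major (T, N, C) output instead. Also, with blank_ind = len(labels) - 1, the check c >= len(labels) in labels_to_text does not filter out the blank. Here is a sketch of best-path (greedy) CTC decoding for time-major output that removes the blank explicitly. The function name greedy_ctc_decode is hypothetical, and unlike decode_batch it does not skip the first two time steps.

```python
import itertools

import torch


def greedy_ctc_decode(log_probs, blank, input_lengths=None):
    """Best-path CTC decoding for time-major (T, N, C) network output.

    Takes the argmax class at every time step, merges consecutive duplicates,
    then drops the blank. Returns one list of label indices per sample.
    """
    T, N, _ = log_probs.shape
    best = log_probs.argmax(dim=2)  # (T, N) most likely class per step
    results = []
    for n in range(N):
        length = T if input_lengths is None else int(input_lengths[n])
        path = best[:length, n].tolist()
        collapsed = [k for k, _ in itertools.groupby(path)]  # merge repeats
        results.append([k for k in collapsed if k != blank])  # then drop blank
    return results


# Usage on a hand-built output whose best path is known: 1 1 <b> 1 2 <b>
blank = 3
path = torch.tensor([1, 1, 3, 1, 2, 3])
fake_output = torch.eye(4)[path].unsqueeze(1)  # (T=6, N=1, C=4)
print(greedy_ctc_decode(fake_output, blank))
```

Note that the blank between the two 1s is what keeps them as two separate labels; the decoded indices can then be mapped to text with labels[k].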

The code also needs labels_set.txt.

Whoops! Available here: https://www.dropbox.com/s/odsxkquhiu1apio/labels_set.txt?dl=0

Yeah, well, your targets include the CTC blank. That’s not allowed.

I can recommend the Distill article Sequence Modelling with CTC for an accessible explanation of how CTC works.

Best regards

Thomas
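With padded (N, S) targets, nn.CTCLoss only looks at targets[i, :target_lengths[i]], so blank padding after that point is harmless; a blank inside the counted part is the problem. A vectorized sketch of that check follows. The name find_blank_in_targets is hypothetical, and it assumes the padded layout used in this thread.

```python
import torch


def find_blank_in_targets(targets, target_lengths, blank):
    """Indices of samples whose counted target contains the blank index.

    targets: (N, S) padded LongTensor; only targets[i, :target_lengths[i]]
    is checked, because that is the part nn.CTCLoss actually uses.
    """
    positions = torch.arange(targets.shape[1], device=targets.device)
    counted = positions.unsqueeze(0) < target_lengths.unsqueeze(1)  # (N, S) bool
    hits = (targets == blank) & counted
    return hits.any(dim=1).nonzero(as_tuple=True)[0].tolist()


blank = 249
targets = torch.tensor([
    [2, 98, 101, 83, 0, 249, 249],        # fine: blank only in the padding
    [249, 249, 249, 249, 249, 249, 249],  # length 1: counted target is [blank]
])
lengths = torch.tensor([5, 1])
print(find_blank_in_targets(targets, lengths, blank))  # [1]
```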

Right, so my understanding of how to handle varying length sequences was to have a max length of the targets (12, in this case), and then pad with the blank label at the end, and have a tensor for the length of each of the targets. Is there a better way to handle variable length targets without padding?

I don’t think the blank label should be within the target lengths at all, except for the cases where there are no labels. I think I tried setting the target length to 0 for those at one point instead of 1, but it threw an error.

This dataset works well enough in Keras where I originally prototyped it. I think the main change in specification there is that Keras has the blank label not in the label set (i.e., blank_ind = len(labels)) whereas Pytorch requires the blank label be in the label set, correct?
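In PyTorch the blank is simply one of the C output classes, selected with the blank argument of nn.CTCLoss (default 0). In the script, nclass = len(labels) and blank_ind = len(labels) - 1, so the last entry of labels_set.txt plays the role of the blank. Here is a sketch of two common options, assuming NUM_LABELS = 250 real symbols (a hypothetical figure) and a hypothetical helper shift_for_blank_zero.

```python
import torch
import torch.nn as nn

NUM_LABELS = 250              # assumption: real symbols, excluding the blank
NUM_CLASSES = NUM_LABELS + 1  # one extra output unit reserved for the blank


def shift_for_blank_zero(targets):
    """Option B: move every real label up by one so index 0 is free for blank."""
    return targets + 1


torch.manual_seed(0)
T, N = 63, 2
log_probs = torch.randn(T, N, NUM_CLASSES).log_softmax(2)  # (T, N, C)
input_lengths = torch.full((N,), T, dtype=torch.long)
targets = torch.tensor([[2, 98, 101, 83, 0], [0, 76, 81, 98, 1]])
target_lengths = torch.tensor([5, 5])

# Option A (Keras-like): blank is the extra, last class; labels are unchanged.
loss_a = nn.CTCLoss(blank=NUM_LABELS)(log_probs, targets, input_lengths, target_lengths)

# Option B: blank is class 0; real labels are shifted so none of them is 0.
loss_b = nn.CTCLoss(blank=0)(log_probs, shift_for_blank_zero(targets),
                             input_lengths, target_lengths)
print(loss_a.item(), loss_b.item())
```

Either way, the network has to emit NUM_CLASSES outputs (e.g. nclass = len(labels) + 1 if labels_set.txt has no blank entry), so that no real label has to double as the blank.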

Indeed, CTC loss will ignore anything beyond target_lengths. However, your lengths seemed to be such that blanks are included before that point; maybe the lengths are off?
There is a bug in GPU CTC when you have target length zero, so if you need that after you’ve fixed your inputs, maybe the CPU version is better at the moment.

Best regards

Thomas
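For an image with no text, a target length of 0 means "the right output is blank at every step". This is a minimal sketch of computing such a loss on the CPU, assuming a PyTorch build where nn.CTCLoss accepts zero target lengths (the original poster reported an error on the 1.0.x version used here). The values of T, C and the padding are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, C, blank = 10, 5, 0
log_probs = torch.randn(T, 2, C).log_softmax(2)  # (T, N, C) with N = 2
targets = torch.tensor([[1, 2, 3],
                        [4, 4, 4]])               # row 1 is padding only
target_lengths = torch.tensor([3, 0])             # sample 1 has no labels
input_lengths = torch.full((2,), T, dtype=torch.long)

ctc = nn.CTCLoss(blank=blank, reduction='none')
# On a GPU setup, call .cpu() on log_probs and targets to use the CPU kernel.
losses = ctc(log_probs.cpu(), targets.cpu(), input_lengths, target_lengths)

# With an empty target the only valid alignment is "blank at every step",
# so its loss is the negative sum of the blank log-probabilities.
all_blank_nll = -log_probs[:, 1, blank].sum()
print(losses[1].item(), all_blank_nll.item())
```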

I don’t think the blank_ind (in this case 249) appears in the targets within the target_lengths. I did a check via something like:

            for i in range(inputs['image'].shape[0]):
                length = outputs['label_length'][i]
                if length > 1 and blank_ind in outputs['the_labels'][i][:length]:
                    print(outputs['the_labels'][i])
                    print(outputs['the_labels'][i][:length])
                    print(outputs['label_length'][i])
                    raise Exception

in the __getitem__ method, and it never fired. Just so that I’m clear on the behavior of target_lengths, for an item like:

tensor([  0,  76,  81,  98,   1, 249, 249, 249, 249, 249, 249, 249])

with a target length of 5, that should give:

tensor([ 0, 76, 81, 98,  1])

as the actual target within the CTCLoss code?
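One way to confirm that only the first target_length entries of each padded row are used is to change the padding value and check that the loss does not move. This sketch uses random log-probabilities on the CPU; pad_rows is a hypothetical helper, and the rows are taken from the sample labels above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, N, C, blank = 63, 2, 250, 249
log_probs = torch.randn(T, N, C).log_softmax(2)  # (T, N, C)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.tensor([5, 4])

real = [[0, 76, 81, 98, 1], [2, 24, 113, 13]]


def pad_rows(rows, value, width=12):
    """Pad each label list to `width` with `value` (layout used in this thread)."""
    return torch.tensor([r + [value] * (width - len(r)) for r in rows])


ctc = nn.CTCLoss(blank=blank, reduction='none')
loss_blank_pad = ctc(log_probs, pad_rows(real, blank), input_lengths, target_lengths)
loss_other_pad = ctc(log_probs, pad_rows(real, 7), input_lengths, target_lengths)
print(torch.allclose(loss_blank_pad, loss_other_pad))  # True

# The part CTCLoss sees for the first row:
print(pad_rows(real, blank)[0, :int(target_lengths[0])])
```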

I tried setting the blank_ind for the CTCLoss to something else that’s never in the targets (like 230), and it started predicting that after a while. In earlier tests I set the blank_ind to 0 (and added 1 to all existing labels to offset them), and it showed the same behavior of predicting the blank index. I also tried removing the data items without any labels, as well as using the CPU version, but the same behavior shows up.

80%+ of your input samples in the first data I got are of length one. Those seem to be mostly invalid because their target is the blank.

Right, so those would be the ones that I would set to target_length=0 if that didn’t throw an error. There is code to not include them in the batches that I’ve used so that indeed every training item has a valid label sequence (it’s commented out under # For removing all blank images). When that code is uncommented and the blank images are excluded, it still goes to predicting just the blank label after a few epochs.

Sorry for the confusion with this! My original thought process was that I thought that some negative examples would be good so that it didn’t predict labels when there aren’t any. I could, of course, just train a separate CNN that makes the decision of blank or not and then if not blank then feed it to the CTC trained network for generating labels.
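If the negative examples are worth keeping, the "no text" samples could be re-encoded as empty targets rather than a single blank, subject to the GPU caveat about target length zero mentioned earlier in the thread. This sketch assumes the encoding seen here (no text means target_length == 1 with the blank in position 0); blank_only_to_empty is a hypothetical helper.

```python
import torch


def blank_only_to_empty(targets, target_lengths, blank):
    """Set target_length to 0 for samples whose whole target is one blank.

    A genuine single-label target (length 1, non-blank) is left unchanged.
    """
    placeholder = (target_lengths == 1) & (targets[:, 0] == blank)
    return target_lengths.masked_fill(placeholder, 0)


blank = 249
targets = torch.tensor([[249, 249, 249],   # "no text" placeholder
                        [5, 249, 249],     # real one-label target
                        [2, 98, 0]])
lengths = torch.tensor([1, 1, 3])
print(blank_only_to_empty(targets, lengths, blank).tolist())  # [0, 1, 3]
```

The same mask, negated, can instead be used to drop these samples from a batch, which is what the commented-out filter in __getitem__ does with label_length != 1.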

The CTC loss value for one case is infinite, and for all the other examples after that the loss is ‘nan’. I realised that sometimes the input length is less than the target length, which might be a problem, so I padded the input with blank characters.
But the loss is still inf when input length = target length + 1.
Does CTC loss require the ratio between the input length and the target length to be greater than 1?
Thank you for your time.
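As far as CTC itself goes, there is no fixed ratio: an alignment needs one frame per label plus one extra blank frame between every pair of identical adjacent labels (otherwise they would collapse into one). So input length = target length + 1 is only enough when the target has at most one such repeat. A small sketch of that bound, with a hypothetical function name:

```python
def min_ctc_input_length(target):
    """Smallest number of time steps that can emit `target` under CTC.

    One frame per label, plus a separating blank frame for every pair of
    identical adjacent labels.
    """
    repeats = sum(1 for a, b in zip(target, target[1:]) if a == b)
    return len(target) + repeats


print(min_ctc_input_length([2, 98, 101, 83, 0]))  # 5
print(min_ctc_input_length([1, 1, 2, 2]))         # 6
```

When the input is shorter than this bound, no alignment exists and the loss is infinite.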

There is an option to zero infinite losses in the nightlies/next version of PyTorch.

Best regards

Thomas
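That option is the zero_infinity argument of nn.CTCLoss, which zeros infinite losses and their gradients. This sketch, assuming a PyTorch version that has it, builds an impossible sample: the target [1, 1, 2, 2] needs at least 6 frames, but only 5 are given (target length + 1).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, C = 5, 5
log_probs = torch.randn(T, 1, C).log_softmax(2)  # (T, N=1, C)
targets = torch.tensor([[1, 1, 2, 2]])           # two repeats: needs 6 frames
input_lengths = torch.tensor([T])
target_lengths = torch.tensor([4])

plain = nn.CTCLoss(blank=0, reduction='none')
safe = nn.CTCLoss(blank=0, reduction='none', zero_infinity=True)

loss_plain = plain(log_probs, targets, input_lengths, target_lengths)  # infinite
loss_safe = safe(log_probs, targets, input_lengths, target_lengths)    # zeroed
print(loss_plain, loss_safe)
```

Zeroing only hides impossible samples from the optimizer; it is still worth checking the data for them.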

Just a note that I’ve managed to solve the issue I was having. Comparing my code against https://github.com/Holmeyoung/crnn-pytorch as a working version, the issue was that when I ported the dataset from the Keras implementation, the text targets had the shape (batch size, max_sequence_length): any text shorter than max_sequence_length (12 in my case) was padded with blank labels. In contrast, the working PyTorch version passes the text as a 1-dimensional tensor whose length is the sum of the lengths of all texts in the batch.

Hope this helps anyone having a similar error.
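A sketch of converting padded (N, S) targets into the 1-D concatenated layout described above, using a boolean mask; flatten_targets is a hypothetical helper. On the CPU in a recent PyTorch the two layouts give the same loss, as the check at the end shows. This does not explain why the padded layout misbehaved for the posters on 1.0.x; it only shows the conversion.

```python
import torch
import torch.nn as nn


def flatten_targets(padded, lengths):
    """Concatenate padded[i, :lengths[i]] for every row into one 1-D tensor."""
    positions = torch.arange(padded.shape[1], device=padded.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)  # (N, S) bool
    return padded[mask]  # boolean indexing keeps row-major order


blank = 249
padded = torch.tensor([[2, 98, 101, 83, 0, blank, blank],
                       [2, 24, 113, 13, 109, 0, blank]])
lengths = torch.tensor([5, 6])
flat = flatten_targets(padded, lengths)

torch.manual_seed(0)
T, N, C = 63, 2, 250
log_probs = torch.randn(T, N, C).log_softmax(2)
input_lengths = torch.full((N,), T, dtype=torch.long)
ctc = nn.CTCLoss(blank=blank, reduction='none')
print(torch.allclose(ctc(log_probs, padded, input_lengths, lengths),
                     ctc(log_probs, flat, input_lengths, lengths)))
```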


I had the same problem. I can confirm that flattening the targets into a 1D array instead of padding them works. Maybe we can submit an issue on GitHub, but that requires a reproducible example. Maybe someone could build one using an ASR model on an open-source dataset. Thank you @mmcauliffe for the solution.