What is the right way to use CTCLoss for HTR?

Hi, I would like some help with using CTCLoss for a handwritten text recognition (HTR) task.

The problem I am facing is that although the loss decreases quite rapidly at the start, it levels out and no longer decreases about 10% of the way into an epoch. When I inspect the output of the model's forward pass, it seems like the model has decided that the best way to minimize CTCLoss is to predict only the first character, and then blanks for the rest of the sequence.
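
For reference, the inspection is essentially a greedy CTC decode along these lines (a minimal sketch; outputs is assumed to be the (T, N, C) log-probability tensor returned by the model, with the blank at index 0 as in the nn.CTCLoss default):

import torch

def greedy_ctc_decode(outputs, blank=0):
    # Pick the most likely class at every timestep: (T, N)
    best_path = outputs.argmax(dim=2)
    decoded = []
    for seq in best_path.transpose(0, 1):   # iterate over the batch
        prev = blank
        indices = []
        for idx in seq.tolist():
            # Collapse repeated symbols and drop blanks (standard greedy CTC decoding)
            if idx != blank and idx != prev:
                indices.append(idx)
            prev = idx
        decoded.append(indices)
    return decoded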

Originally, based on this other post, it seemed that the problem was that I was padding with blank labels when encoding label strings into label vectors (I couldn't batch them otherwise, since they are of varying length). I have since changed to concatenating all labels into a single vector, as stated in the documentation. However, the loss still levels out and the model predicts gibberish.

Below is the training code and how I used CTCLoss.

for epoch in range(1):

    training_loss = np.array([])
    validation_loss = np.array([])

    with tqdm(total=len(train_set), position=0, leave=True, desc="Epoch {}".format(epoch)) as pbar:
        for i, data in enumerate(train_loader, 0):
            # zero parameter gradients
            optimizer.zero_grad()

            images, labels, label_lengths = data['image'].to(device), data['label'], data['label_length'].to(device)
            # Forward pass
            outputs = model(images)

            # Batch size is usually 64, except for the (possibly smaller) last batch
            batch_size = len(data['label_length'])
            # The model outputs W/8 = 1024/8 = 128 timesteps for every sample
            input_lengths = torch.full(size=(batch_size,), fill_value=128, dtype=torch.long)

            # Concatenate all labels in the batch into a single 1-D tensor of size sum(target_lengths)
            targets = torch.from_numpy(tokenizer.encode(''.join(labels)))

            # Loss calculation
            loss = ctc_criterion(outputs, targets.to(device), input_lengths, label_lengths)
            # Backpropagate loss
            loss.backward()
            # Gradient descent
            optimizer.step()

            training_loss = np.append(training_loss, loss.item())
            pbar.update(train_loader.batch_size)
            pbar.set_description(desc="Epoch {}, Train Loss {}".format(epoch, loss.item()))

Am I doing anything wrong with the way I use CTC? Let me know if more information is required.

Thanks.

The outputs variable has to contain log-probabilities. Are you applying a log-softmax to outputs before passing it into the CTCLoss function? Additionally, can you show how ctc_criterion is instantiated?
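
For reference, here is a minimal sketch of what nn.CTCLoss expects (illustrative shapes only, not your actual model):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Log-probabilities of shape (T, N, C): T timesteps, N batch size,
# C classes including the blank at index 0 (the nn.CTCLoss default)
T, N, C, S = 128, 4, 97, 10
logits = torch.randn(T, N, C)
log_probs = F.log_softmax(logits, dim=2)     # CTCLoss expects log-probabilities

# Dummy targets of shape (N, S), avoiding the blank index 0
targets = torch.randint(1, C, (N, S), dtype=torch.long)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), S, dtype=torch.long)

loss = nn.CTCLoss()(log_probs, targets, input_lengths, target_lengths)
print(loss.item())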

Hi @charan_Vjy,

Yes, I am using LogSoftmax as the last layer. Below is my model:

class HTRModel(nn.Module):
    def __init__(self, imgSize=(1024, 128), charBaseSize=97):
        super(HTRModel, self).__init__()

        self.cnn = nn.Sequential(
            # Expected input size (B, 1, W, H)
            # Expected output size (B, 16, W/2, H/2)
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(inplace=True),
            nn.MaxPool2d(2,stride=2),

            # Expected input size (B, 16, W/2, H/2)
            # Expected output size (B, 32, W/4, H/4)
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(inplace=True),
            nn.MaxPool2d(2,stride=2),

            nn.Dropout(p=0.2),

            # Expected input size (B, 32, W/4, H/4)
            # Expected output size (B, 48, W/8, H/8)
            nn.Conv2d(in_channels=32, out_channels=48, kernel_size=3, padding=1),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(inplace=True),
            nn.MaxPool2d(2,stride=2),

            nn.Dropout(p=0.2),

            # Expected input size (B, 48, W/8, H/8)
            # Expected output size (B, 64, W/8, H/8)
            nn.Conv2d(in_channels=48, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(inplace=True),

            nn.Dropout(p=0.2),

            # Expected input size (B, 64, W/8, H/8)
            # Expected output size (B, 80, W/8, H/8)
            nn.Conv2d(in_channels=64, out_channels=80, kernel_size=3, padding=1),
            nn.BatchNorm2d(80),
            nn.LeakyReLU(inplace=True)
        )

        # Expected RNN vector size (S=W/8, B, num_features=H/8*num_output_channels)
        # Expected output size (S, B, hidden_size*2). hidden_size*2 because of bi-directional.
        rnn_input_size = int(imgSize[1]/8 * 80)
        self.rnn = nn.LSTM(input_size=rnn_input_size, hidden_size=256, num_layers=4, dropout=0.5, bidirectional=True)

        # Expected input size (S, B, hidden_size*2)
        # Expected output size (S, B, charBaseSize)
        self.dense = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=512, out_features=charBaseSize),
            nn.LogSoftmax(dim=2)
        )

    def forward(self, x):
        output = self.cnn(x)
        # Move the batch dimension to second and the width dimension (now the sequence dimension) to first,
        # then flatten the channel and height dimensions into the feature dimension
        output = output.transpose(0,1).transpose(0,-1)
        output = output.reshape(output.size(0), output.size(1), -1)
        output = self.rnn(output)
        output = self.dense(output[0])
        return output
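
As a quick sanity check on the shapes, a dummy forward pass gives the (S, B, charBaseSize) output that CTCLoss expects (a sketch only; it assumes the HTRModel class above and that images are fed in as (B, 1, 128, 1024), i.e. channels first with height 128 and width 1024):

import torch

model = HTRModel()
dummy = torch.randn(2, 1, 128, 1024)   # assumed (B, C, H, W) layout with H=128, W=1024
out = model(dummy)
print(out.shape)   # torch.Size([128, 2, 97]) -> (S, B, charBaseSize) log-probabilities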

Also, this is how I instantiate the optimizer and ctc_criterion:

from torch import optim

# Send model to gpu
model = HTRModel()
model.to(device)

# Instantiate Loss and Optimizers
ctc_criterion = nn.CTCLoss()
optimizer = optim.RMSprop(model.parameters(), lr=0.001, momentum=0.01)

ctc_criterion is just an instantiation of nn.CTCLoss with default parameters.
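
Spelled out, the defaults mean the blank is class index 0 and the loss is averaged over the batch, i.e. the line above is equivalent to:

# Equivalent to nn.CTCLoss() with its default arguments
ctc_criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)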