Handwriting recognition model struggles to predict

Hi everyone, I’m Leonida.

I’m fairly new to PyTorch and I’m learning by trial & errors and using tutorials. Until now I have only built simple models, but now I am trying to make a CNRR to read human handwriting (input images are grayscale).

My model is the following (there are an attention layer for RNN and skips connections in CNN):

class HandwritingRecognitionModel(nn.Module):
    def __init__(self, num_classes):
        super(HandwritingRecognitionModel, self).__init__()
        self.cnn = nn.Sequential(
            self._conv_block(1, 32, 3, 1),
            nn.MaxPool2d(2, 2),
            self._conv_block(32, 64, 3, 1),
            nn.MaxPool2d(2, 2),
            self._conv_block(64, 128, 3, 1),
            self._conv_block(128, 128, 3, 1),
            nn.MaxPool2d(2, 2),
            self._conv_block(128, 256, 3, 1),
            self._conv_block(256, 256, 3, 1),
            nn.MaxPool2d(2, 2),
        self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(512, 512)
        self.output = nn.Linear(512, num_classes + 1)  # +1 for CTC blank 

    def _conv_block(self, in_channels, out_channels, kernel_size, stride):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=1),

    def forward(self, x):
        x = self.cnn(x)
        b, c, h, w = x.size()
        x = x.view(b, c, h * w).permute(0, 2, 1)
        x, _ = self.lstm(x)

        x = self.fc(x)
        x = nn.functional.relu(x)
        x = nn.functional.dropout(x, 0.2)

        x = self.output(x)

        return nn.functional.log_softmax(x, dim=2)

Images and labels in dataset have different shapes, so I need to pad them (CTC blank character number is 79) with collate_fn:

def collate_fn(batch):
    images, labels = zip(*batch)

    max_height = max(img.shape[1] for img in images)
    max_width = max(img.shape[2] for img in images)
    padded_images = torch.ones(len(images), 1, max_height, max_width)
    for i, img in enumerate(images):
        padded_images[i, :, :img.shape[1], :img.shape[2]] = img

    max_label_length = max(len(label) for label in labels)
    padded_labels = torch.full((len(labels), max_label_length), len(char_set), dtype=torch.long)  # Use blank as padding
    label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
    for i, label in enumerate(labels):
        padded_labels[i, :len(label)] = label

    label_mask = torch.arange(max_label_length)[None, :] < label_lengths[:, None]

    return padded_images, padded_labels, label_lengths

For dataset definition, I defined my CustomCTCDataset which extends Dataset; then:

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn, num_workers = 8, pin_memory = True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn, num_workers=8, pin_memory = True)
#same for val_loader

Finally, this is my train method were I use, also, tqdm:

def train(model, loader, optimizer, criterion, device):
    total_loss = 0
    start_time = time.time()
    pbar = tqdm(total=len(loader), desc=f"Training")
    for batch_idx, (data, target, target_lengths) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        target_lengths = target_lengths.to(device)
        output = model(data)
        output_log_softmax_perm = output.permute(1, 0, 2)
        input_lengths = torch.full(size=(data.size(0),), fill_value=output.size(1), dtype=torch.long, device=device)
        loss = criterion(output_log_softmax_perm, target, input_lengths, target_lengths)

        total_loss += loss.item()

        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    epoch_time = time.time() - start_time
    return total_loss / len(loader), epoch_time

Optimizer is fairly standard: Adam with learning rate of 0.01. I use, also, a ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.75) scheduler.

Training lasted 75 epochs with both validation and training loss constantly dcreasing. At the end, val_loss = 0.05.

I was satisfied by the result, so I starded testing. With a simple input, model had to predict "A MOVE to stop Mr. Gaitskell from ", but predicted “tnnnnttntntn; I (hY),thbg !” plus a long sequence of blanks. This is the decode ,method:

def decode_prediction(model, image, char_set):
    if not isinstance(image, torch.Tensor):
        image = torch.from_numpy(image).float()
    if image.dim() == 2:
        image = image.unsqueeze(0).unsqueeze(0) 
    elif image.dim() == 3:
        image = image.unsqueeze(0)  
    with torch.no_grad():
        output = model(image)

    output = output.permute(1, 0, 2)

    pred_indices = torch.argmax(output, dim=0).squeeze()

    char_list = [char_set[idx] for idx in pred_indices if idx < len(char_set)]

    decoded_chars = []
    prev_char = None
    for char in char_list:
        if char != len(char_set): 
        prev_char = char

    decoded_string = ''.join(decoded_chars)
    return decoded_string

Why such a loose prediction? Where do I mistake?

Thank you.