CRNN with CTC loss predicts only blank labels

Hello,

I am trying to train CRNN with CTC loss. But, after some iterations, the model predicts only blank labels. I have tried solutions provided to similar problems. But, none worked in my case.

I am providing code, Colab notebook, and dataset.

Any help will be really appreciated.

Thanks in advance.


import os
import sys
import cv2
import tqdm
import glob
import torch
import torchvision
from torch import nn
from PIL import Image
from itertools import groupby
import matplotlib.pyplot as plt
from collections import OrderedDict
import torchvision.transforms as transforms
from torch.nn.modules.pooling import MaxPool2d
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.utils.data import Dataset, DataLoader

cwd = os.getcwd()

"""# Model"""

class BiLSTM(torch.nn.Module):
    def __init__(self, n_In, n_hidden, n_Out):
        super(BiLSTM, self).__init__()
        self.rnn = torch.nn.LSTM(n_In, n_hidden, bidirectional=True)
        self.embedding = torch.nn.Linear(n_hidden * 2, n_Out)

    def forward(self, input):
        output, _ = self.rnn(input)
        T, b, h = output.size()
        t_rec = output.view(T * b, h)
        output = self.embedding(output)
#         output = sigmoid(output)
#         m = nn.Sigmoid()
        output = output.view(T, b, -1)
        return output

# n_In = 512
# n_hidden = 30
# n_Out = 62
# rnn = BiLSTM(n_In, n_hidden, n_Out)

class CRNN(torch.nn.Module):
    def __init__(self, n_classes):
        self.n_classes = n_classes
        super(CRNN, self).__init__()
        self.cnn = torch.nn.Sequential(OrderedDict([
            
            ("conv0" , torch.nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = 3, stride = 1, padding = 1)),
            ("relu0" , torch.nn.ReLU()),
            ("pooling0" , torch.nn.MaxPool2d(kernel_size = 2, stride = 2)),

            ("conv1" , torch.nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 3, stride = 1, padding = 1)),
            ("relu1" , torch.nn.ReLU()),
            ("pooling1" , torch.nn.MaxPool2d(kernel_size = 2, stride = 2)),

            ("conv2" , torch.nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size = 3, stride = 1, padding = 1)),

            ("batchnorm2" , torch.nn.BatchNorm2d(256)),
            ("relu2" , torch.nn.ReLU()),

            ("conv3" , torch.nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, stride = 1, padding = 1)),
            ("relu3" , torch.nn.ReLU()),

            ("pooling2" , torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1))),

            ("conv4" , torch.nn.Conv2d(in_channels = 256, out_channels = 512, kernel_size = 3, stride = 1, padding = 1)),

            ("batchnorm4" , torch.nn.BatchNorm2d(512)),
            ("relu4" , torch.nn.ReLU()),

            ("conv5" , torch.nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1)),
            ("relu5" , torch.nn.ReLU()),

            ("pooling3" , torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1))),

            ("conv6" , torch.nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 2, stride = 1, padding=0)),
            ("batchnorm6" , torch.nn.BatchNorm2d(512)),
            ("relu6" , torch.nn.ReLU())

        ]) )

        # self.cnn = self.conv_block
        self.lstm = torch.nn.Sequential( BiLSTM(512, 256, 256),
                                        BiLSTM(256, 256, n_classes)    )

        self.sigmoid = nn.Sigmoid()

    
    def forward(self, input):
        conv_output = self.cnn(input)
        b, c, h, w = conv_output.size()
        conv_output = conv_output.squeeze(2)
        # print(f"conv_output.shape : {conv_output.shape}")
        conv_output = conv_output.permute(2, 0, 1)  # [w, b, c]
#         print(f"conv_output.shape : {conv_output.shape}")
        lstm_op = self.lstm(conv_output)
        lstm_op = self.sigmoid(lstm_op)
        # output = lstm_op.transpose(1,0) 
        return lstm_op

blank_label = 0
alphabets = "~0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'.- "
n_classes = len(alphabets)
model = CRNN(n_classes)
print(model)

# cwd

# files = glob.glob(cwd + "/cropped_data/*.jpg")
# files = glob.glob(cwd + "/Dataset/out/*.jpg")
files = glob.glob(cwd + "/Dataset/dataset1/test/*.jpg")
files.sort()
# files

"""# Utility Functions"""

# alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxys0123456789-"

alphabet_list = [a for a in alphabets]

alphabet_dict = {}
for i,c in list(enumerate(alphabet_list)):
    alphabet_dict[c]=i

def encode(string):
    output_list = []
    for char in string:
        output_list.append(alphabet_dict[char])
    return output_list
        
def decode(List):
    string = ""
    for element in List:
        string+=alphabets[int(element)]
    return string

def clean(string):
    new_string = ""
    for i in range(len(string)):
        if string[i] != '-' and (not (i > 0 and string[i - 1] == string[i])):
            new_string+=string[i]
    return new_string

def image_resize(image, width = None, height = None, inter = cv2.INTER_AREA):
    dim = None
    (h, w) = image.shape[:2]
    if width is None and height is None:
        return image
    if width is None:
        r = height / float(h)
        dim = (int(w * r), height)
    else:
        r = width / float(w)
        dim = (width, int(h * r))
    resized = cv2.resize(image, dim, interpolation = inter)
    return resized

def run_model(img):
    image = image_resize(img, height = 32)
    print(image.shape)

    # plt.imshow(image)
    # plt.show()
    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    image_t = torch.Tensor(gray_image)
    image_t = torch.unsqueeze(image_t, 2)
    image_t = image_t.permute([2, 0, 1])
    image_t = torch.unsqueeze(image_t, 0)
    print(f"image_t.shape : {image_t.shape}")
    op = model.forward(image_t)
    return op

def draw_tsmaps(image, output):
    image = image_resize(image, height = 200)
    NOStamps = len(output)
    h,w,_ = image.shape
    distance = w / NOStamps
    y1 = 0
    y2 = h-1
    for stamp in range(NOStamps):
        x1 = int(stamp * distance)
        cv2.line(image,(x1,y1),(x1,y2),(0,255,0),1)
        
        image = cv2.putText(image, output[stamp], (x1,h//2), cv2.FONT_HERSHEY_SIMPLEX, 
                           1, (255,0,0), 1, cv2.LINE_AA)

    return image

def predict(frame, dev=False):
    if not dev:
        output = run_model(frame)
    else:
        output = model.forward(frame)
#     print(output.shape)
#     n_o_t,_,_ = output.shape
    _, output = output.max(2)
    output = output.data
    output = output.reshape(-1)
    output = output.tolist()
    print(output)
    output = decode(output)
    clean_output = clean(output)
    return output,clean_output

for file in files[:10]:
    image = cv2.imread(file)
    output, clean_output = predict(image)

    n_o_t = len(output)
    image = draw_tsmaps(image, output)
    plt.imshow(image)
    plt.show()

    print(f"{n_o_t} results")

    print(f"output : {output}")
    print(f"clean_output : {clean_output}")

    print("_______________________________________________________________________")

"""# Defining Dataloaders"""

def custom_collate(batch):
    label_padding_value = alphabet_dict[' ']
    width = [item['img'].shape[2] for item in batch]
    indexes = [item['idx'] for item in batch]
    imgs = torch.ones([len(batch), batch[0]['img'].shape[0], batch[0]['img'].shape[1], 
                       max(width)], dtype=torch.float32)
    for idx, item in enumerate(batch):
        try:
            imgs[idx, :, :, 0:item['img'].shape[2]] = item['img']
        except:
            print(imgs.shape)
    item = {'img': imgs, 'idx':indexes}
    if 'label' in batch[0].keys():
        labels = [item['label'] for item in batch]
        len_labels = [len(label) for label in labels]
        max_size_label = max(len_labels)
        for label in labels:
            label.extend([label_padding_value] * (max_size_label - len(label)))
        item['label'] = labels


        item['label'] = torch.Tensor(item['label'])
    return item
    
class OCRDataset(Dataset):
    def __init__(self, opt):
        super(Dataset, self).__init__()
        self.path = os.path.join(opt['path'], opt['imgdir'])
        self.images = os.listdir(self.path)
        self.nSamples = len(self.images)
        f = lambda x: os.path.join(self.path, x)
        self.imagepaths = list(map(f, self.images))
       	transform_list =  [#transforms.Resize((128,128)),
                           transforms.Grayscale(1),
                            transforms.ToTensor(), 
                            transforms.Normalize((0.5,), (0.5,))]
        self.transform = transforms.Compose(transform_list)
        self.collate_fn = custom_collate

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        imagepath = self.imagepaths[index]
        imagefile = os.path.basename(imagepath)
        img = Image.open(imagepath)
        if self.transform is not None:
            img = self.transform(img)
        item = {'img': img, 'idx':index}
        item['label'] = imagefile.split('_')[0]
        item['label'] = encode(item['label'])
        return item

cwd

batch_size = 16
opt_train = {
    "path" : cwd,
    # "imgdir" : "namewise"
    "imgdir" : "Dataset/dataset1/train"
}

opt_valid = {
    "path" : cwd,
    # "imgdir" : "namewise"
    "imgdir" : "Dataset/dataset1/valid"
}

opt_test = {
    "path" : cwd,
    # "imgdir" : "namewise"
    "imgdir" : "Dataset/dataset1/test"
}

train_data = OCRDataset(opt_train)
train_loader = DataLoader(train_data, shuffle=True, drop_last=True, batch_size=batch_size, num_workers=0, collate_fn = custom_collate)


valid_data = OCRDataset(opt_valid)
valid_loader = DataLoader(valid_data, shuffle=False, drop_last=True, batch_size=batch_size, num_workers=0, collate_fn = custom_collate)


test_data = OCRDataset(opt_test)
test_loader = DataLoader(test_data, shuffle=False, drop_last=True, batch_size=batch_size, num_workers=0, collate_fn = custom_collate)

criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)#, momentum=0, dampening=0, weight_decay=0, nesterov=False)
optimizer = torch.optim.Adadelta(model.parameters())

def average(losses):
    return sum(losses)/len(losses)

def calc_accuracy(Y_train, Y_pred):
    train_correct = 0
    train_total = 0
    _, max_index = torch.max(Y_pred, dim=2)
    for i in range(batch_size):
        raw_prediction = list(max_index[:, i].detach().cpu().numpy())
        prediction = torch.IntTensor([c for c, _ in groupby(raw_prediction) if c != blank_label])
        if len(prediction) == len(Y_train[i]) and torch.all(prediction.eq(Y_train[i])):
            train_correct += 1
        train_total += 1
    accuracy = train_correct / train_total
    return accuracy

def train():
    losses = []
    accuracies = []
    for L in tqdm.tqdm(train_loader):
        X_train = (L['img'])
        Y_train = (L['label'])
        # print(f"Y_train: {Y_train}")

        batch_size = X_train.shape[0]  
        optimizer.zero_grad()
        Y_pred = model(X_train).cuda()
        Y_pred_0 = Y_pred.permute(1, 0, 2)  
        input_lengths = torch.IntTensor([len(t) for t in Y_pred_0])
        target_lengths = torch.IntTensor([len(t) for t in Y_train])
        Y_pred = Y_pred.log_softmax(dim=2)
        loss = criterion(Y_pred, Y_train, input_lengths, target_lengths)
        losses.append(loss)
        loss.backward()
        optimizer.step()
        accuracy = calc_accuracy(Y_train, Y_pred)
        accuracies.append(accuracy)

    avg_loss = average(losses)
    avg_accuracy = average(accuracies)

    print(f"Training Loss : {avg_loss}")
    print(f"Training Accuracy : {avg_accuracy}")

def valid():

    losses = []
    accuracies = []
    for L in tqdm.tqdm(train_loader):

        X_train = (L['img'])
        Y_train = (L['label'])

        batch_size = X_train.shape[0]  
        optimizer.zero_grad()
        Y_pred = model(X_train).cuda()
        Y_pred_0 = Y_pred.permute(1, 0, 2) 
        input_lengths = torch.IntTensor([len(t) for t in Y_pred_0])
        target_lengths = torch.IntTensor([len(t) for t in Y_train])
        Y_pred = Y_pred.log_softmax(dim=2)
        loss = criterion(Y_pred, Y_train, input_lengths, target_lengths)
        losses.append(loss)

        accuracy = calc_accuracy(Y_train, Y_pred)
        accuracies.append(accuracy)

    avg_loss = average(losses)
    avg_accuracy = average(accuracies)

    print(f"Validation Loss : {avg_loss}")
    print(f"Validation Accuracy : {avg_accuracy}")

epochs = 200
for epoch in range(epochs):
    train()
    valid()

for file in files[:10]:
    image = cv2.imread(file)
    output, clean_output = predict(image)

    n_o_t = len(output)
    image = draw_tsmaps(image, output)
    plt.imshow(image)
    plt.show()

    print(f"{n_o_t} results")

    print(f"output : {output}")
    print(f"clean_output : {clean_output}")

    print("_______________________________________________________________________")

Link of the Colab notebook:
https://drive.google.com/file/d/10d_RLxEkzl8F2yCvalNKF-zmULyaY6Pg/view?usp=sharing

Link of the Dataset:
https://drive.google.com/drive/folders/1a0U3j-fkxW1nUvYQiqi327xVMBkjRk-K?usp=sharing

The problem is solved.

I was using a sigmoid after the last layer. Removing it solved the problem.

Here is the architecture of the current working model:

CRNN(
  (cnn): Sequential(
    (conv0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu0): ReLU()
    (pooling0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU()
    (pooling1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU()
    (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu3): ReLU()
    (pooling2): MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)
    (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu4): ReLU()
    (conv5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu5): ReLU()
    (pooling3): MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)
    (conv6): Conv2d(512, 512, kernel_size=(2, 2), stride=(1, 1))
    (batchnorm6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu6): ReLU()
  )
  (lstm): Sequential(
    (0): BiLSTM(
      (rnn): LSTM(512, 256, bidirectional=True)
      (embedding): Linear(in_features=512, out_features=256, bias=True)
    )
    (1): BiLSTM(
      (rnn): LSTM(256, 256, bidirectional=True)
      (embedding): Linear(in_features=512, out_features=67, bias=True)
    )
  )
)