OCR Model not Learning, blank label predicted each epoch | Code review

Hello Everyone,

I am attempting to create a model for OCR …
Input: image of a Hindi word → Output: recognition of each character.

I have written a very basic CRNN model (CNN feature extractor + bidirectional LSTM + CTC loss). Here are my current observations:

  1. The model does not seem to learn even after multiple epochs; it predicts only the “blank” label (index 0) for every word.
  2. Interestingly, the loss does decrease, but the decoder output is still all “blank” (index 0).

Note: I take each time step’s hidden state and pass it through a linear layer:
(batchSize, seqLen, featureSize=numDirns*hiddenSize) → (batchSize, seqLen, numClasses) → log-softmax
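
For reference, a minimal sketch of the shape flow described above (toy tensors; 129 = vocabulary size, 400 = numDirns*hiddenSize in my setup):

import torch
import torch.nn as nn

N, T, H, C = 4, 6, 400, 129                # batch, time steps, numDirns*hiddenSize, classes
rnn_out = torch.randn(N, T, H)             # every time step's hidden state (batch_first)
log_probs = nn.functional.log_softmax(nn.Linear(H, C)(rnn_out), dim=2)  # (N, T, C)
log_probs = log_probs.permute(1, 0, 2)     # nn.CTCLoss expects (T, N, C)
print(log_probs.shape)                     # torch.Size([6, 4, 129])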

Kindly review and give your valuable suggestions.

# -*- coding: utf-8 -*-
"""Image_to_text.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1TfeTXVL5jk0vAzM_dqNA7yVLSrvyjloo

# Importing Dependencies
"""

import torch
import torch.nn as nn
import torch.optim as optim
import cv2
import os
import matplotlib.pyplot as plt
import random
import numpy as np
from matplotlib import font_manager
import time
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from torchvision import transforms
from skimage import io
from tqdm import tqdm
from collections import OrderedDict

"""Mounting Drive"""

from google.colab import drive
drive.mount('/gdrive')

# Commented out IPython magic to ensure Python compatibility.
# %cd /gdrive/MyDrive/Capstone_project_data/ImgtoText/Cropped_Images/cropped_data

!ls

#!unzip 'Sample Train.zip'

if torch.cuda.is_available():
  MyDevice = torch.device("cuda")
else:
  MyDevice = torch.device("cpu")

print(MyDevice)

"""# Data Loader """

## All Hindi alphabets: " " at index 0, then the full Devanagari block (U+0900-U+097F) ##
## NOTE: index 0 doubles as the padding/space character and the default CTC blank ##
all_hindi_alpha = [" "]+[chr(i) for i in range(2304,2432)]
all_hindi_alpha = {all_hindi_alpha[i]:i for i in range(len(all_hindi_alpha))}
print(all_hindi_alpha)
print(len(all_hindi_alpha))  # 129
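
"""Sanity check: index 0 of the vocabulary is the space character, which is also the default blank index of nn.CTCLoss (blank=0)."""

index_to_char = {v: k for k, v in all_hindi_alpha.items()}
print(repr(index_to_char[0]))  # ' '
print(all_hindi_alpha[" "])    # 0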

with open("annotations.txt") as fh:
  allLineList = fh.readlines()
# each line is "<image>\t<label>"; the last line is dropped (assumed trailing blank line)
labelGenerator = [allLineList[i].split('\t')[1].strip('\n') for i in range(len(allLineList)-1)]

print(labelGenerator)

for eWord in labelGenerator:
  for eStr in eWord:
    print(eStr)
  break

"""Encode Hindi Words"""

def gt_rep(word, letter2index, max_str_len=None, device='cpu'):
  """Encode a word as a (max_str_len, 1) tensor of vocabulary indices,
  right-padded with spaces (index 0)."""
  rep = torch.zeros([max_str_len, 1], dtype=torch.long).to(device)
  if len(word) < max_str_len:
    word = word + " " * (max_str_len - len(word))
  for letter_index, letter in enumerate(word):
    rep[letter_index][0] = letter2index[letter]
  return rep
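
"""A small usage example of gt_rep (hypothetical 3-character word, max_str_len=6): the space padding becomes index 0, identical to the CTC blank index."""

sample = gt_rep("कमल", all_hindi_alpha, max_str_len=6)
print(sample.squeeze(1).tolist())  # three character indices followed by three 0s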

class MyCollateClass():
  """Collate fn: stacks images and zero-pads labels to the longest word in
  the batch; note the padding index 0 is the same as the CTC blank."""
  def __init__(self, dim=1):
    self.dim = dim

  def padTensors(self, tensorLabels, maxStrLen):
    finList, sequenceLens = [], []
    for eTensor in tensorLabels:
      sequenceLens.append(len(eTensor))  # true (unpadded) length, for CTC target_lengths
      if eTensor.size()[0] < maxStrLen:
        diff = maxStrLen - eTensor.size()[0]
        finalTensor = torch.cat([eTensor, torch.zeros(diff)], dim=0).int()
      else:
        finalTensor = eTensor.int()
      finList.append(finalTensor)
    finTensor = torch.stack(finList)
    sequenceLens = torch.Tensor(sequenceLens).int()
    return finTensor, sequenceLens

  def PadCollate(self, batch):
    def _get_max_sentence_len(LabelList):
      return max(eTensor.size()[0] for eTensor in LabelList)

    ImgTensorList, LabelList = zip(*((eDict['image'], eDict['label']) for eDict in batch))
    maxStr_Len = _get_max_sentence_len(LabelList)
    LabelTensor, seqLens = self.padTensors(LabelList, maxStr_Len)
    ImgTensor = torch.stack(ImgTensorList)
    return {"Images": ImgTensor, "Label": LabelTensor, "SeqLength": seqLens}

  def __call__(self, batch):
    return self.PadCollate(batch)

class HindiTextDataset(Dataset):
  def __init__(self, LabelList=None, RootDirectory=None, transform=None, vocabList=None):
    self.LabelList = LabelList
    self.root_dir = RootDirectory
    self.transform = transform
    self.vocabList = vocabList

  def __len__(self):
    return len(self.LabelList)

  def _get_letter_to_index(self, idx):
    # map each character of the label to its vocabulary index
    return torch.Tensor([self.vocabList.get(eChar) for eChar in self.LabelList[idx]]).int()

  def __getitem__(self, idx):
    # images are stored as <idx>.jpg under root_dir
    img = io.imread(''.join([self.root_dir, str(idx), '.jpg']))
    img_tensor = self.transform(img)
    img_tensor = transforms.functional.resize(img_tensor, (128, 128))
    label_tensor = self._get_letter_to_index(idx)
    return {'image': img_tensor, 'label': label_tensor}

transform_batch = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

TextDataset = HindiTextDataset(labelGenerator,"cropped_dir/",transform = transform_batch,vocabList=all_hindi_alpha)

for i in range(len(TextDataset)):
  print("Image Size :", TextDataset[i]["image"].shape, "Label :", TextDataset[i]["label"])
  if i == 3:
    break

"""Custom Dataset Loader """

batch_size = 4
dataloader1 = DataLoader(TextDataset, batch_size=batch_size,
                        shuffle=True, num_workers=0,collate_fn=MyCollateClass())

for i,data in enumerate(dataloader1):
  print(data['Images'].size(),data["Label"].size(),data["SeqLength"])
  if i>3:
    break

"""Show Sample Data"""

for ind, data in enumerate(dataloader1):
  if ind > 0: break
  fig = plt.figure()
  nrows, ncols = batch_size//2, batch_size//2
  ax = fig.subplots(nrows, ncols)
  counter = 0
  for i in range(nrows):
    for j in range(ncols):
      # show a crop of the first (normalized) channel of each image
      ax[i, j].imshow(data['Images'][counter][0][0:60, 0:100])
      counter += 1

"""# ENCODER PART"""

class FeatureExtractor(nn.Module):
  def __init__(self):
    super().__init__()
    self.encoder = nn.Sequential(
        nn.Conv2d(3, 64, 3),    ## (N,3,128,128) -> (N,64,126,126)
        nn.ReLU(),
        nn.MaxPool2d(2, 2),     ## (N,64,63,63)
        nn.Conv2d(64, 256, 3),  ## (N,64,63,63) -> (N,256,61,61)
        nn.ReLU(),
        nn.MaxPool2d(2, 2),     ## (N,256,30,30)
        nn.Conv2d(256, 512, 5), ## (N,256,30,30) -> (N,512,26,26)
        nn.ReLU(),
        nn.MaxPool2d(2, 2),     ## (N,512,13,13)
        nn.MaxPool2d(2, 2)      ## (N,512,6,6)
    )

  def forward(self, x):
    return self.encoder(x)

obj = FeatureExtractor()
for data in dataloader1:
  dd = obj(data["Images"])
  print(dd.size())  # expect (batch, 512, 6, 6)
  break

"""# DECODER PART"""

## The extracted CNN feature map feeds the recurrent decoder: each of the 6
## spatial rows (width x channel depth = 6*512 values) is one LSTM time step ##
class LSTM_Net(nn.Module):
  def __init__(self, input_size=None, batch_size=None, hidden_size=None,
               output_size=None, numLayers=1, numDirns=1):
    super().__init__()
    self.hidden_size = hidden_size
    self.batch_size = batch_size
    self.numLayers = numLayers
    self.numDirns = numDirns
    # bidirectional must agree with numDirns, otherwise the hidden-state
    # shapes from init_hiddenlayer() will not match
    self.lstm_cell = nn.LSTM(input_size, hidden_size, num_layers=self.numLayers,
                             batch_first=True, bidirectional=(self.numDirns == 2))
    self.h2o = nn.Linear(self.numDirns * hidden_size, output_size)
    self.F = nn.ReLU()  # note: ReLU clamps the logits to >= 0 before the softmax
    self.SoftMAX = nn.Softmax(dim=2)

  def forward(self, input, hidden, training=True):
    out, hidden = self.lstm_cell(input, hidden)
    output = self.F(self.h2o(out))
    if training:
      return nn.functional.log_softmax(output, dim=2)  # log-probs for CTC
    return self.SoftMAX(output)

  def init_hiddenlayer(self, device='cpu'):
    shape = (self.numLayers * self.numDirns, self.batch_size, self.hidden_size)
    return (torch.zeros(shape).to(device), torch.zeros(shape).to(device))
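
"""Shape sanity check with dummy tensors (toy batch of 2; 129 = len(all_hindi_alpha))."""

enc_chk = FeatureExtractor()
dec_chk = LSTM_Net(input_size=6*512, batch_size=2, hidden_size=200,
                   output_size=len(all_hindi_alpha), numLayers=1, numDirns=2)
x_chk = torch.randn(2, 3, 128, 128)
feats_chk = enc_chk(x_chk)                                      # (2, 512, 6, 6)
feats_chk = feats_chk.permute(0, 2, 3, 1).reshape(2, 6, 6*512)  # 6 time steps of 3072 features
out_chk = dec_chk(feats_chk, dec_chk.init_hiddenlayer())        # (2, 6, 129) log-probs
print(feats_chk.shape, out_chk.shape)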

"""# Inference & Accuracy"""

def _computeAccuracy(source, target):
  """Greedy-decodes the (T, N, C) log-probs and compares against the
  zero-padded targets; returns (partial-match rate, full-match rate)."""
  source = source.detach().permute(1, 0, 2)  # (T,N,C) -> (N,T,C), batch first

  def _convertTarget_toList(target):
    for e_array in np.asarray(target.tolist()):
      yield np.delete(e_array, np.where(e_array == 0)).tolist()  # strip the padding

  def _convertSource_toList(source):
    source_idx = torch.argmax(source, dim=2)  # best label per time step
    finalList = []
    for eList in source_idx:
      idx = torch.where(eList == 0)[0]  # positions of blanks
      if len(idx) == 0: continue        # sequences without any blank are skipped
      eList, idx = eList.int().tolist(), idx.int().tolist()
      collapsedList = []
      temp = 0
      for i in idx:
        # collapse repeats between consecutive blanks, then drop the blanks
        if i == idx[-1]: ss = list(OrderedDict.fromkeys(eList[temp:]))
        else: ss = list(OrderedDict.fromkeys(eList[temp:i]))
        ss = [eVal for eVal in ss if eVal != 0]
        if len(ss) != 0:
          collapsedList.extend(ss)
        temp = i
      finalList.append(collapsedList)
    return finalList

  targetList = list(_convertTarget_toList(target))
  sourceList = _convertSource_toList(source)
  score, partialScore = 0, 0
  for eSource, eTarget in zip(sourceList, targetList):
    if eSource == []: continue
    if set(eSource).issubset(eTarget) and len(eSource) == len(eTarget): score += 1
    elif set(eSource).issubset(eTarget): partialScore += 1

  return partialScore / source.size()[0], score / source.size()[0]
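
"""Toy check of the collapsing logic: greedy CTC decoding takes the per-step argmax, merges repeats, and drops blanks, so the path [0, 5, 5, 0, 7, 7] should decode to [5, 7] and count as a full match."""

toy_path = torch.tensor([[0, 5, 5, 0, 7, 7]])                  # (N=1, T=6)
toy_logits = nn.functional.one_hot(toy_path, num_classes=10).float()
toy_logits = toy_logits.permute(1, 0, 2)                       # (T, N, C), as in batchTrain
toy_target = torch.tensor([[5, 7, 0, 0]])                      # zero-padded target
print(_computeAccuracy(toy_logits, toy_target))                # expect (0.0, 1.0)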

"""# BatchTraining """

def batchTrain(TextDataSet, EncoderModel=None, DecoderModel=None, lossFn=None, optimFn=None,
               scheduler=None, batchSize=None, epochs=1, device='cpu'):
  loadedData = DataLoader(TextDataSet, batch_size=batchSize, shuffle=True, collate_fn=MyCollateClass())
  bestLoss = float('inf')  # lowest epoch loss seen so far, for checkpointing
  for i in range(epochs):
    cummLoss = 0
    for ind, data in enumerate(loadedData):
      # skip the last, smaller batch: init_hiddenlayer() is fixed to batchSize
      if data['Images'].size()[0] != batchSize: continue
      optimFn.zero_grad()
      img_tensor = data['Images'].to(device)
      targets = data['Label'].to(device)
      target_lengths = data['SeqLength']
      # CNN features: (N,512,6,6) -> (N,6,6,512) -> (N,6,3072), i.e. 6 time steps
      encoder_fpass = EncoderModel(img_tensor).permute(0, 2, 3, 1)
      encoder_fpass_new = encoder_fpass.reshape(encoder_fpass.shape[0], encoder_fpass.shape[1],
                                                encoder_fpass.shape[2] * encoder_fpass.shape[3])
      hidden = DecoderModel.init_hiddenlayer(device)
      # permute to (T, N, C), the layout nn.CTCLoss expects
      decoder_fpass = DecoderModel(encoder_fpass_new, hidden).permute(1, 0, 2)
      input_lengths = torch.Tensor([decoder_fpass.size()[0]] * batchSize).int()
      loss = lossFn(decoder_fpass, targets.cpu(), input_lengths.cpu(), target_lengths.cpu())
      loss.backward()
      optimFn.step()
      cummLoss += loss.item() * batchSize

    #scheduler.step(loss)
    loss_per_epoch = cummLoss / batchSize  # sum of per-batch losses, not a per-sample mean
    if loss_per_epoch < bestLoss:
      bestLoss = loss_per_epoch
      torch.save({
          'epoch': i,
          'encoder_state_dict': EncoderModel.state_dict(),
          'decoder_state_dict': DecoderModel.state_dict(),
          'optimizer_state_dict': optimFn.state_dict(),
          'loss': loss_per_epoch,
          }, "model_Text_Recognition.pt")
    print("Loss Per Epoch {} for epoch {}".format(loss_per_epoch, i + 1))
    partialScore, fullScore = _computeAccuracy(decoder_fpass, targets)
    print("Partial Score --> {} , Full Score --> {}".format(partialScore, fullScore))

## HYPERPARAMETERS
batchSize = 20
num_layers = 1
num_dirn = 2          # bidirectional LSTM
hidden_size = 200
lr = 0.005

ctc_loss = nn.CTCLoss(zero_infinity=True)  # blank index defaults to 0
EncoderObj = FeatureExtractor().to(MyDevice)
# input_size = 6*512: each of the 6 feature-map rows, flattened over width x channels
DecoderObj = LSTM_Net(input_size=6*512, batch_size=batchSize, hidden_size=hidden_size,
                      output_size=len(all_hindi_alpha), numLayers=num_layers, numDirns=num_dirn).to(MyDevice)
optimFn = optim.Adam(list(EncoderObj.parameters()) + list(DecoderObj.parameters()), lr=lr)

batchTrain(TextDataset,EncoderModel=EncoderObj,DecoderModel=DecoderObj,lossFn=ctc_loss,optimFn=optimFn,batchSize=batchSize,epochs=50,device=MyDevice)

checkpoint = torch.load('model_Text_Recognition.pt')
EncoderObj.load_state_dict(checkpoint['encoder_state_dict'])
DecoderObj.load_state_dict(checkpoint['decoder_state_dict'])
optimFn.load_state_dict(checkpoint['optimizer_state_dict'])
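
"""A minimal greedy-decode inference sketch using the restored weights (assumes a single 128x128 RGB crop; the hidden state is built by hand because init_hiddenlayer() is fixed to the training batch size)."""

idx2char = {v: k for k, v in all_hindi_alpha.items()}

def predict(img_path):
  img = transform_batch(io.imread(img_path))
  img = transforms.functional.resize(img, (128, 128)).unsqueeze(0).to(MyDevice)
  with torch.no_grad():
    feats = EncoderObj(img).permute(0, 2, 3, 1).reshape(1, 6, 6*512)
    hidden = (torch.zeros(num_layers*num_dirn, 1, hidden_size).to(MyDevice),
              torch.zeros(num_layers*num_dirn, 1, hidden_size).to(MyDevice))
    probs = DecoderObj(feats, hidden, training=False)  # (1, 6, 129) softmax probs
  path = probs.argmax(dim=2)[0].tolist()
  # merge consecutive repeats, then drop blanks (index 0)
  chars = [idx2char[c] for j, c in enumerate(path) if c != 0 and (j == 0 or c != path[j - 1])]
  return ''.join(chars)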