Hello,
I am trying to train CRNN with CTC loss. But, after some iterations, the model predicts only blank labels. I have tried solutions provided to similar problems. But, none worked in my case.
I am providing code, Colab notebook, and dataset.
Any help will be really appreciated.
Thanks in advance.
import os
import sys
import cv2
import tqdm
import glob
import torch
import torchvision
from torch import nn
from PIL import Image
from itertools import groupby
import matplotlib.pyplot as plt
from collections import OrderedDict
import torchvision.transforms as transforms
from torch.nn.modules.pooling import MaxPool2d
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.utils.data import Dataset, DataLoader
cwd = os.getcwd()
"""# Model"""
class BiLSTM(torch.nn.Module):
def __init__(self, n_In, n_hidden, n_Out):
super(BiLSTM, self).__init__()
self.rnn = torch.nn.LSTM(n_In, n_hidden, bidirectional=True)
self.embedding = torch.nn.Linear(n_hidden * 2, n_Out)
def forward(self, input):
output, _ = self.rnn(input)
T, b, h = output.size()
t_rec = output.view(T * b, h)
output = self.embedding(output)
# output = sigmoid(output)
# m = nn.Sigmoid()
output = output.view(T, b, -1)
return output
# n_In = 512
# n_hidden = 30
# n_Out = 62
# rnn = BiLSTM(n_In, n_hidden, n_Out)
class CRNN(torch.nn.Module):
def __init__(self, n_classes):
self.n_classes = n_classes
super(CRNN, self).__init__()
self.cnn = torch.nn.Sequential(OrderedDict([
("conv0" , torch.nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = 3, stride = 1, padding = 1)),
("relu0" , torch.nn.ReLU()),
("pooling0" , torch.nn.MaxPool2d(kernel_size = 2, stride = 2)),
("conv1" , torch.nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 3, stride = 1, padding = 1)),
("relu1" , torch.nn.ReLU()),
("pooling1" , torch.nn.MaxPool2d(kernel_size = 2, stride = 2)),
("conv2" , torch.nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size = 3, stride = 1, padding = 1)),
("batchnorm2" , torch.nn.BatchNorm2d(256)),
("relu2" , torch.nn.ReLU()),
("conv3" , torch.nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, stride = 1, padding = 1)),
("relu3" , torch.nn.ReLU()),
("pooling2" , torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1))),
("conv4" , torch.nn.Conv2d(in_channels = 256, out_channels = 512, kernel_size = 3, stride = 1, padding = 1)),
("batchnorm4" , torch.nn.BatchNorm2d(512)),
("relu4" , torch.nn.ReLU()),
("conv5" , torch.nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1)),
("relu5" , torch.nn.ReLU()),
("pooling3" , torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1))),
("conv6" , torch.nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 2, stride = 1, padding=0)),
("batchnorm6" , torch.nn.BatchNorm2d(512)),
("relu6" , torch.nn.ReLU())
]) )
# self.cnn = self.conv_block
self.lstm = torch.nn.Sequential( BiLSTM(512, 256, 256),
BiLSTM(256, 256, n_classes) )
self.sigmoid = nn.Sigmoid()
def forward(self, input):
conv_output = self.cnn(input)
b, c, h, w = conv_output.size()
conv_output = conv_output.squeeze(2)
# print(f"conv_output.shape : {conv_output.shape}")
conv_output = conv_output.permute(2, 0, 1) # [w, b, c]
# print(f"conv_output.shape : {conv_output.shape}")
lstm_op = self.lstm(conv_output)
lstm_op = self.sigmoid(lstm_op)
# output = lstm_op.transpose(1,0)
return lstm_op
blank_label = 0
alphabets = "~0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'.- "
n_classes = len(alphabets)
model = CRNN(n_classes)
print(model)
# cwd
# files = glob.glob(cwd + "/cropped_data/*.jpg")
# files = glob.glob(cwd + "/Dataset/out/*.jpg")
files = glob.glob(cwd + "/Dataset/dataset1/test/*.jpg")
files.sort()
# files
"""# Utility Functions"""
# alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxys0123456789-"
alphabet_list = [a for a in alphabets]
alphabet_dict = {}
for i,c in list(enumerate(alphabet_list)):
alphabet_dict[c]=i
def encode(string):
output_list = []
for char in string:
output_list.append(alphabet_dict[char])
return output_list
def decode(List):
string = ""
for element in List:
string+=alphabets[int(element)]
return string
def clean(string):
new_string = ""
for i in range(len(string)):
if string[i] != '-' and (not (i > 0 and string[i - 1] == string[i])):
new_string+=string[i]
return new_string
def image_resize(image, width = None, height = None, inter = cv2.INTER_AREA):
dim = None
(h, w) = image.shape[:2]
if width is None and height is None:
return image
if width is None:
r = height / float(h)
dim = (int(w * r), height)
else:
r = width / float(w)
dim = (width, int(h * r))
resized = cv2.resize(image, dim, interpolation = inter)
return resized
def run_model(img):
image = image_resize(img, height = 32)
print(image.shape)
# plt.imshow(image)
# plt.show()
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
image_t = torch.Tensor(gray_image)
image_t = torch.unsqueeze(image_t, 2)
image_t = image_t.permute([2, 0, 1])
image_t = torch.unsqueeze(image_t, 0)
print(f"image_t.shape : {image_t.shape}")
op = model.forward(image_t)
return op
def draw_tsmaps(image, output):
image = image_resize(image, height = 200)
NOStamps = len(output)
h,w,_ = image.shape
distance = w / NOStamps
y1 = 0
y2 = h-1
for stamp in range(NOStamps):
x1 = int(stamp * distance)
cv2.line(image,(x1,y1),(x1,y2),(0,255,0),1)
image = cv2.putText(image, output[stamp], (x1,h//2), cv2.FONT_HERSHEY_SIMPLEX,
1, (255,0,0), 1, cv2.LINE_AA)
return image
def predict(frame, dev=False):
if not dev:
output = run_model(frame)
else:
output = model.forward(frame)
# print(output.shape)
# n_o_t,_,_ = output.shape
_, output = output.max(2)
output = output.data
output = output.reshape(-1)
output = output.tolist()
print(output)
output = decode(output)
clean_output = clean(output)
return output,clean_output
for file in files[:10]:
image = cv2.imread(file)
output, clean_output = predict(image)
n_o_t = len(output)
image = draw_tsmaps(image, output)
plt.imshow(image)
plt.show()
print(f"{n_o_t} results")
print(f"output : {output}")
print(f"clean_output : {clean_output}")
print("_______________________________________________________________________")
"""# Defining Dataloaders"""
def custom_collate(batch):
label_padding_value = alphabet_dict[' ']
width = [item['img'].shape[2] for item in batch]
indexes = [item['idx'] for item in batch]
imgs = torch.ones([len(batch), batch[0]['img'].shape[0], batch[0]['img'].shape[1],
max(width)], dtype=torch.float32)
for idx, item in enumerate(batch):
try:
imgs[idx, :, :, 0:item['img'].shape[2]] = item['img']
except:
print(imgs.shape)
item = {'img': imgs, 'idx':indexes}
if 'label' in batch[0].keys():
labels = [item['label'] for item in batch]
len_labels = [len(label) for label in labels]
max_size_label = max(len_labels)
for label in labels:
label.extend([label_padding_value] * (max_size_label - len(label)))
item['label'] = labels
item['label'] = torch.Tensor(item['label'])
return item
class OCRDataset(Dataset):
def __init__(self, opt):
super(Dataset, self).__init__()
self.path = os.path.join(opt['path'], opt['imgdir'])
self.images = os.listdir(self.path)
self.nSamples = len(self.images)
f = lambda x: os.path.join(self.path, x)
self.imagepaths = list(map(f, self.images))
transform_list = [#transforms.Resize((128,128)),
transforms.Grayscale(1),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))]
self.transform = transforms.Compose(transform_list)
self.collate_fn = custom_collate
def __len__(self):
return self.nSamples
def __getitem__(self, index):
assert index <= len(self), 'index range error'
imagepath = self.imagepaths[index]
imagefile = os.path.basename(imagepath)
img = Image.open(imagepath)
if self.transform is not None:
img = self.transform(img)
item = {'img': img, 'idx':index}
item['label'] = imagefile.split('_')[0]
item['label'] = encode(item['label'])
return item
cwd
batch_size = 16
opt_train = {
"path" : cwd,
# "imgdir" : "namewise"
"imgdir" : "Dataset/dataset1/train"
}
opt_valid = {
"path" : cwd,
# "imgdir" : "namewise"
"imgdir" : "Dataset/dataset1/valid"
}
opt_test = {
"path" : cwd,
# "imgdir" : "namewise"
"imgdir" : "Dataset/dataset1/test"
}
train_data = OCRDataset(opt_train)
train_loader = DataLoader(train_data, shuffle=True, drop_last=True, batch_size=batch_size, num_workers=0, collate_fn = custom_collate)
valid_data = OCRDataset(opt_valid)
valid_loader = DataLoader(valid_data, shuffle=False, drop_last=True, batch_size=batch_size, num_workers=0, collate_fn = custom_collate)
test_data = OCRDataset(opt_test)
test_loader = DataLoader(test_data, shuffle=False, drop_last=True, batch_size=batch_size, num_workers=0, collate_fn = custom_collate)
criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)#, momentum=0, dampening=0, weight_decay=0, nesterov=False)
optimizer = torch.optim.Adadelta(model.parameters())
def average(losses):
return sum(losses)/len(losses)
def calc_accuracy(Y_train, Y_pred):
train_correct = 0
train_total = 0
_, max_index = torch.max(Y_pred, dim=2)
for i in range(batch_size):
raw_prediction = list(max_index[:, i].detach().cpu().numpy())
prediction = torch.IntTensor([c for c, _ in groupby(raw_prediction) if c != blank_label])
if len(prediction) == len(Y_train[i]) and torch.all(prediction.eq(Y_train[i])):
train_correct += 1
train_total += 1
accuracy = train_correct / train_total
return accuracy
def train():
losses = []
accuracies = []
for L in tqdm.tqdm(train_loader):
X_train = (L['img'])
Y_train = (L['label'])
# print(f"Y_train: {Y_train}")
batch_size = X_train.shape[0]
optimizer.zero_grad()
Y_pred = model(X_train).cuda()
Y_pred_0 = Y_pred.permute(1, 0, 2)
input_lengths = torch.IntTensor([len(t) for t in Y_pred_0])
target_lengths = torch.IntTensor([len(t) for t in Y_train])
Y_pred = Y_pred.log_softmax(dim=2)
loss = criterion(Y_pred, Y_train, input_lengths, target_lengths)
losses.append(loss)
loss.backward()
optimizer.step()
accuracy = calc_accuracy(Y_train, Y_pred)
accuracies.append(accuracy)
avg_loss = average(losses)
avg_accuracy = average(accuracies)
print(f"Training Loss : {avg_loss}")
print(f"Training Accuracy : {avg_accuracy}")
def valid():
losses = []
accuracies = []
for L in tqdm.tqdm(train_loader):
X_train = (L['img'])
Y_train = (L['label'])
batch_size = X_train.shape[0]
optimizer.zero_grad()
Y_pred = model(X_train).cuda()
Y_pred_0 = Y_pred.permute(1, 0, 2)
input_lengths = torch.IntTensor([len(t) for t in Y_pred_0])
target_lengths = torch.IntTensor([len(t) for t in Y_train])
Y_pred = Y_pred.log_softmax(dim=2)
loss = criterion(Y_pred, Y_train, input_lengths, target_lengths)
losses.append(loss)
accuracy = calc_accuracy(Y_train, Y_pred)
accuracies.append(accuracy)
avg_loss = average(losses)
avg_accuracy = average(accuracies)
print(f"Validation Loss : {avg_loss}")
print(f"Validation Accuracy : {avg_accuracy}")
epochs = 200
for epoch in range(epochs):
train()
valid()
for file in files[:10]:
image = cv2.imread(file)
output, clean_output = predict(image)
n_o_t = len(output)
image = draw_tsmaps(image, output)
plt.imshow(image)
plt.show()
print(f"{n_o_t} results")
print(f"output : {output}")
print(f"clean_output : {clean_output}")
print("_______________________________________________________________________")
Link of the Colab notebook:
https://drive.google.com/file/d/10d_RLxEkzl8F2yCvalNKF-zmULyaY6Pg/view?usp=sharing
Link of the Dataset:
https://drive.google.com/drive/folders/1a0U3j-fkxW1nUvYQiqi327xVMBkjRk-K?usp=sharing