Hello,
(I am aware there are several similar questions, but none of the solutions given helped me to solve my problem.)
I am training a CRNN with a CTCLoss using pytorch. At the first few iterations, the predicted labels are all very similar (random sequences of the same 3-4 characters), although the real labels are not. Then, as it trains, the average length of the predicted sequences decreases, until a point where only blank labels are predicted.
Please use the code below to reproduce my problem (PyTorch 1.3.1, Python 3.6, running on CPU). To make the problem easier to reproduce, the code generates random tensors as inputs and ground truths (normally I use real text images as inputs, but the behaviour is the same). Also, I implemented a CTC beam-search decoder with TensorFlow to visualize the outputs (I normally use the PyTorch implementation, but it requires compiling, and the outputs of both implementations are the same).
# Alphabet handled by the recogniser: a character's position in this string
# is the class index used to look it up when decoding.
CHAR_VECTOR = (
    "0123456789"
    "abcdefghijklmnopqrstuvwxyz"
    ". -"
)
# Note that the final character, '-', is the blank character of the CTC loss.
NUM_CLASSES = len(CHAR_VECTOR)
################
''' CRNN network'''
import torch.nn as nn
import torch
class ConvBnRelu(nn.Module):
    """Convolution -> batch normalisation -> ReLU unit.

    The integer ``i`` (0 to 5) selects this unit's kernel size, stride,
    padding and channel counts from the six-entry tables stored on the
    instance.
    """

    def __init__(self, i):
        super().__init__()
        # Per-layer configuration tables, all indexed by ``i``.
        self.ks = [3, 3, 3, 3, 3, 3]                     # kernel sizes
        self.ps = [1, 1, 1, 1, 1, 1]                     # paddings
        self.ss = [1, 1, 1, 1, 1, 1]                     # strides
        self.out_layers = [64, 64, 128, 128, 256, 256]   # output channels
        self.in_layers = [3, 64, 64, 128, 128, 256]      # input channels

        width = self.out_layers[i]
        self.conv = nn.Conv2d(
            in_channels=self.in_layers[i],
            out_channels=width,
            kernel_size=self.ks[i],
            stride=self.ss[i],
            padding=self.ps[i],
        )
        self.bn = nn.BatchNorm2d(width)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, c):
        """Apply the convolution, then batch norm, then ReLU to ``c``."""
        return self.relu(self.bn(self.conv(c)))
class RecognitionCNN(nn.Module):
    """Convolutional feature extractor built from six ConvBnRelu units.

    A max-pool with a (2, 1) window and (2, 1) stride is applied after
    units 1, 2, 3, 4 and 5, so the first spatial dimension is halved five
    times while the second is left untouched.
    """

    def __init__(self):
        super().__init__()
        self.layer0 = ConvBnRelu(0)
        self.layer1 = ConvBnRelu(1)
        self.layer2 = ConvBnRelu(2)
        self.layer3 = ConvBnRelu(3)
        self.layer4 = ConvBnRelu(4)
        self.layer5 = ConvBnRelu(5)
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=(0, 0))

    def forward(self, cropped_features):
        """Run the input through all six units and the five pooling steps."""
        # The first two units run back to back with no pooling in between.
        features = self.layer1(self.layer0(cropped_features))
        # Every remaining unit is preceded by a pooling step ...
        for unit in (self.layer2, self.layer3, self.layer4, self.layer5):
            features = unit(self.maxpool(features))
        # ... and a final pooling step follows the last unit.
        return self.maxpool(features)
class RecognitionBiLSTM(nn.Module):
    """Bidirectional LSTM followed by a linear classifier per time step.

    ``weight_decay`` and ``is_training`` are only stored on the instance.
    The dropout layer is constructed but is not currently applied in
    ``forward``.
    """

    def __init__(self, input_num = 256, dropout= 0.3, hidden_num=256, weight_decay=1e-5, is_training=True):
        super().__init__()
        self.input_num = input_num
        self.hidden_num = hidden_num
        self.weight_decay = weight_decay
        self.dropout = dropout
        self.num_classes = NUM_CLASSES
        self.is_training = is_training
        self.rnn = nn.LSTM(self.input_num, self.hidden_num, bidirectional=True)
        # Both directions are concatenated, hence 2 * hidden_num input features.
        self.fully_connected = nn.Linear(self.hidden_num * 2, self.num_classes)
        self.dropout_layer = nn.Dropout(self.dropout)

    def forward(self, input):
        """Return per-time-step class scores of shape (T, b, num_classes)."""
        lstm_out, _state = self.rnn(input)
        # TODO: re-enable dropout on the LSTM output.
        steps, batch, width = lstm_out.size()  # steps = max_width, batch = num_boxes, width = num_hidden*2
        flattened = lstm_out.view(steps * batch, width)
        scores = self.fully_connected(flattened)  # (steps*batch, num_classes)
        return scores.view(steps, batch, -1)  # (steps, batch, num_classes)
class Recognition(nn.Module):
    """CRNN recogniser: RecognitionCNN features fed into RecognitionBiLSTM."""

    def __init__(self, is_training):
        super().__init__()
        self.cnn = RecognitionCNN()
        self.rnn = RecognitionBiLSTM(is_training=is_training)

    def forward(self, input):
        """Return the BiLSTM scores computed from the CNN features of ``input``."""
        features = self.cnn(input)
        _, _, height, _ = features.size()
        assert height == 1, "the height of cnn_output must be 1"
        # Drop the singleton height axis, then move width to the front:
        # [num_boxes, num_hidden, max_width] -> [max_width, num_boxes, num_hidden]
        sequence = features.squeeze(2).permute(2, 0, 1)
        return self.rnn(sequence)
#############
'''Criterion'''
from torch.nn import CTCLoss
class RecognitionLoss(nn.Module):
    """CTC criterion applied to raw (un-normalised) network scores.

    Args:
        blank: class index of the CTC blank symbol. The default, 38, is the
            value that was previously hard-coded.
        reduction: reduction mode forwarded to ``CTCLoss``.
        zero_infinity: forwarded to ``CTCLoss``; when true, infinite losses
            and their gradients are zeroed instead of propagated.
    """

    def __init__(self, blank=38, reduction='mean', zero_infinity=False):
        super().__init__()
        self.ctc_loss = CTCLoss(blank=blank, reduction=reduction, zero_infinity=zero_infinity)

    def forward(self, pred, target, input_lengths, target_lengths):
        """Return the CTC loss for ``pred``.

        ``pred`` holds raw scores with the class axis at dimension 2; they
        are turned into log-probabilities here, as ``CTCLoss`` requires.
        """
        log_probs = pred.log_softmax(2)
        return self.ctc_loss(log_probs, target, input_lengths, target_lengths)
################
'''Compare the decoded output and input.
Note that, to make this easier to run, the TensorFlow ctc_beam_search decoder is used (I normally use the PyTorch
implementation available at https://github.com/parlance/ctcdecode, but it requires compiling; both produce the
same results).'''
def encoded_to_string(encoded_text):
    """Turn a sequence of class indices into text via CHAR_VECTOR.

    Entries equal to -1 are skipped; every other entry is used as an index
    into CHAR_VECTOR.
    """
    characters = []
    for index in encoded_text:
        if index != -1:
            characters.append(CHAR_VECTOR[index])
    return ''.join(characters)
def tf_decoder(rnn_pred, input_lengths):
    """Beam-search decode the network scores with TensorFlow (1.x session API).

    Args:
        rnn_pred: torch tensor of scores laid out as
            (time, batch, num_classes), converted to float32 for TensorFlow.
        input_lengths: torch tensor with one sequence length per batch entry.

    Returns:
        A list with one decoded string per batch entry.
    """
    import tensorflow as tf
    import numpy as np

    # Move both tensors to host memory before converting; previously only
    # ``rnn_pred`` was moved, so ``input_lengths`` on another device failed.
    lengths = input_lengths.detach().cpu().numpy().astype(np.int32)
    logits = rnn_pred.detach().cpu().numpy().astype(np.float32)

    # Build the ops in a private graph. Adding them to the process-wide
    # default graph on every call (as before) makes that graph grow with a
    # new constant copy of the logits plus decoder ops at each training
    # iteration, steadily increasing memory use and session start-up time.
    graph = tf.Graph()
    with graph.as_default():
        decoded, _log_prob = tf.nn.ctc_beam_search_decoder(
            tf.convert_to_tensor(logits), lengths, merge_repeated=False)
        # Rows shorter than the longest decoded sequence are padded with -1,
        # which encoded_to_string skips.
        dense_op = tf.sparse_tensor_to_dense(decoded[0], default_value=-1)

    with tf.Session(graph=graph) as sess:
        dense_decoded = sess.run(dense_op)

    return [encoded_to_string(box) for box in dense_decoded]
def print_input_vs_output(pred,target, input_lengths, target_lengths):
    """Print each ground-truth string next to the decoded prediction.

    Predictions come from ``tf_decoder``; each ground-truth row of
    ``target`` is cut to its entry in ``target_lengths`` before decoding.
    """
    predictions = tf_decoder(pred, input_lengths)
    truths = [encoded_to_string(row[:size]) for row, size in zip(target, target_lengths)]

    final_index = len(predictions) - 1
    pieces = []
    for index, (predicted, truth) in enumerate(zip(predictions, truths)):
        pieces.append('truth : {} / pred : {}'.format(truth, predicted))
        # A separator follows every entry except the last prediction's.
        if index < final_index:
            pieces.append(' \n')

    rule = '----------------------------'
    print(rule)
    print(''.join(pieces))
    print(rule)
#############
''' Train'''
if __name__ == '__main__':
    # Reproduction script: trains the CRNN on random images and random
    # targets, printing ground truth versus decoded prediction each step.
    model = Recognition(is_training=True)
    criterion = RecognitionLoss()
    lr = 0.0001
    model = model.to('cpu')
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    T = 50  # Input sequence length
    N = 16  # Batch size
    S = 30  # Target sequence length of longest target in batch
    S_min = 10  # Minimum target length, for demonstration purposes
    # The number of classes is NUM_CLASSES (fixed by CHAR_VECTOR); the blank
    # is its last index, so valid target labels are 0 .. NUM_CLASSES - 2.
    num_labels = NUM_CLASSES - 1

    for ite in range(1000):
        # Random stand-in images; no gradient with respect to the input is needed.
        images = torch.randn(N, 3, 32, T)
        # Sample from every non-blank class. The previous range [1, 38) was
        # taken from an example whose blank index is 0 and therefore never
        # produced class 0, although the blank here is the last class.
        target = torch.randint(low=0, high=num_labels, size=(N, S), dtype=torch.long)
        input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
        target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)

        pred = model(images)
        loss = criterion(pred, target, input_lengths, target_lengths)
        print_input_vs_output(pred, target, input_lengths, target_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Could anyone help solve my problem by any chance? Thanks