Hi everyone, I’m Leonida.
I’m fairly new to PyTorch and I’m learning by trial and error, using tutorials. Until now I have only built simple models, but now I am trying to build a CRNN to read human handwriting (input images are grayscale).
My model is the following (there is an attention layer for the RNN and skip connections in the CNN):
import torch
import torch.nn as nn

class HandwritingRecognitionModel(nn.Module):
    def __init__(self, num_classes):
        super(HandwritingRecognitionModel, self).__init__()
        self.cnn = nn.Sequential(
            self._conv_block(1, 32, 3, 1),
            nn.MaxPool2d(2, 2),
            self._conv_block(32, 64, 3, 1),
            nn.MaxPool2d(2, 2),
            self._conv_block(64, 128, 3, 1),
            self._conv_block(128, 128, 3, 1),
            nn.MaxPool2d(2, 2),
            self._conv_block(128, 256, 3, 1),
            self._conv_block(256, 256, 3, 1),
            nn.MaxPool2d(2, 2),
        )
        # Bidirectional LSTM over the flattened feature map: 256 features in, 2 * 256 out
        self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(512, 512)
        self.output = nn.Linear(512, num_classes + 1)  # +1 for the CTC blank

    def _conv_block(self, in_channels, out_channels, kernel_size, stride):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

    def forward(self, x):
        x = self.cnn(x)  # (b, 256, h, w) with h = H/16, w = W/16 after four 2x2 poolings
        b, c, h, w = x.size()
        x = x.view(b, c, h * w).permute(0, 2, 1)  # (b, h * w, 256): flatten h and w into one sequence axis
        x, _ = self.lstm(x)
        x = self.fc(x)
        x = nn.functional.relu(x)
        x = nn.functional.dropout(x, 0.2)  # note: F.dropout defaults to training=True, even under model.eval()
        x = self.output(x)
        return nn.functional.log_softmax(x, dim=2)  # (b, seq_len, num_classes + 1)
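For reference, running a dummy grayscale batch through it gives one timestep per spatial position of the final feature map, i.e. (H/16) * (W/16) of them (32 x 128 is just an example size, my real images vary):

model = HandwritingRecognitionModel(num_classes=79)  # 79 characters + 1 CTC blank
dummy = torch.randn(2, 1, 32, 128)  # (batch, channels, height, width)
out = model(dummy)
print(out.shape)  # torch.Size([2, 16, 80]) -> (32/16) * (128/16) = 16 timesteps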
Images and labels in the dataset have different shapes, so I need to pad them in a collate_fn (the CTC blank character index is 79):
def collate_fn(batch):
    images, labels = zip(*batch)
    # pad images to the max height/width in the batch (padding value 1.0 = white)
    max_height = max(img.shape[1] for img in images)
    max_width = max(img.shape[2] for img in images)
    padded_images = torch.ones(len(images), 1, max_height, max_width)
    for i, img in enumerate(images):
        padded_images[i, :, :img.shape[1], :img.shape[2]] = img
    # pad labels to the max label length, using the blank index (len(char_set) == 79) as padding
    max_label_length = max(len(label) for label in labels)
    padded_labels = torch.full((len(labels), max_label_length), len(char_set), dtype=torch.long)
    label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
    for i, label in enumerate(labels):
        padded_labels[i, :len(label)] = label
    label_mask = torch.arange(max_label_length)[None, :] < label_lengths[:, None]  # currently unused
    return padded_images, padded_labels, label_lengths
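As a quick sanity check of collate_fn, I can run it on two dummy samples (made-up sizes; labels are 1-D LongTensors of class indices, and char_set with its 79 characters is assumed to be defined already):

sample_a = (torch.rand(1, 40, 200), torch.tensor([3, 7, 12]))
sample_b = (torch.rand(1, 32, 150), torch.tensor([5, 1]))
images, labels, lengths = collate_fn([sample_a, sample_b])
print(images.shape)  # torch.Size([2, 1, 40, 200])
print(labels)        # second row padded with 79 (the blank index)
print(lengths)       # tensor([3, 2])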
For the dataset definition, I defined my own CustomCTCDataset, which extends Dataset.
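Simplified, it looks roughly like this (image_paths, transcriptions, and char_to_idx stand in for my actual data, and torchvision.io is used here just as an example loader):

from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image, ImageReadMode

class CustomCTCDataset(Dataset):
    def __init__(self, image_paths, transcriptions, char_to_idx):
        self.image_paths = image_paths
        self.transcriptions = transcriptions
        self.char_to_idx = char_to_idx

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # load one grayscale image as a (1, H, W) float tensor in [0, 1]
        image = read_image(self.image_paths[idx], mode=ImageReadMode.GRAY).float() / 255.0
        # encode the transcription as a 1-D LongTensor of class indices (blank is not in char_to_idx)
        label = torch.tensor([self.char_to_idx[c] for c in self.transcriptions[idx]], dtype=torch.long)
        return image, label

Then the loaders: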
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn, num_workers=8, pin_memory=True)
# same for val_loader
Finally, this is my train method, where I also use tqdm:
import time
from tqdm import tqdm

def train(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    start_time = time.time()
    pbar = tqdm(total=len(loader), desc="Training")
    for batch_idx, (data, target, target_lengths) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        target_lengths = target_lengths.to(device)
        optimizer.zero_grad()
        output = model(data)  # (batch, seq_len, num_classes + 1), log-probabilities
        # CTCLoss expects (seq_len, batch, num_classes + 1)
        output_log_softmax_perm = output.permute(1, 0, 2)
        # every padded image produces the same output length, so input_lengths is constant
        input_lengths = torch.full(size=(data.size(0),), fill_value=output.size(1), dtype=torch.long, device=device)
        loss = criterion(output_log_softmax_perm, target, input_lengths, target_lengths)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pbar.update(1)
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    pbar.close()
    epoch_time = time.time() - start_time
    return total_loss / len(loader), epoch_time
The optimizer is fairly standard: Adam with a learning rate of 0.01. I also use a ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.75) scheduler.
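For completeness, everything is wired up roughly like this (validate() is a helper that mirrors train() without the backward pass; blank=79 matches the padding above):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HandwritingRecognitionModel(num_classes=len(char_set)).to(device)  # len(char_set) == 79
criterion = nn.CTCLoss(blank=len(char_set))  # blank index 79, same as in collate_fn
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.75)

for epoch in range(75):
    train_loss, epoch_time = train(model, train_loader, optimizer, criterion, device)
    val_loss = validate(model, val_loader, criterion, device)  # mirrors train() without backprop
    scheduler.step(val_loss)  # ReduceLROnPlateau steps on the monitored metric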
Training lasted 75 epochs, with both validation and training loss constantly decreasing. At the end, val_loss = 0.05.
I was satisfied with the result, so I started testing. With a simple input, the model had to predict "A MOVE to stop Mr. Gaitskell from ", but predicted “tnnnnttntntn; I (hY),thbg !” plus a long sequence of blanks. This is the decode method:
def decode_prediction(model, image, char_set):
    if not isinstance(image, torch.Tensor):
        image = torch.from_numpy(image).float()
    if image.dim() == 2:
        image = image.unsqueeze(0).unsqueeze(0)  # (H, W) -> (1, 1, H, W)
    elif image.dim() == 3:
        image = image.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)
    model.eval()
    with torch.no_grad():
        output = model(image)
    output = output.permute(1, 0, 2)  # (seq_len, 1, num_classes + 1)
    pred_indices = torch.argmax(output, dim=2).squeeze()  # best class index per timestep
    # blank indices (== len(char_set)) are dropped by this filter
    char_list = [char_set[idx] for idx in pred_indices if idx < len(char_set)]
    decoded_chars = []
    prev_char = None
    for char in char_list:
        if char != len(char_set):  # always true: char is a string and blanks were already filtered out
            decoded_chars.append(char)
            prev_char = char  # tracked but never used afterwards
    decoded_string = ''.join(decoded_chars)
    return decoded_string
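For reference, my understanding of standard CTC greedy decoding is that repeated indices should be collapsed before blanks are removed. A minimal sketch of that (this is not the code I ran above, and I am not sure the two behave the same):

def ctc_greedy_decode(log_probs, char_set):
    # log_probs: (1, seq_len, num_classes + 1); blank index == len(char_set)
    blank = len(char_set)
    pred_indices = log_probs.argmax(dim=2).squeeze(0).tolist()
    decoded, prev = [], None
    for idx in pred_indices:
        if idx != prev and idx != blank:  # collapse repeats first, then drop blanks
            decoded.append(char_set[idx])
        prev = idx
    return ''.join(decoded)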
Why such a poor prediction? Where is my mistake?
Thank you.