Can someone please point out why this is happening? I spent hours looking for similar errors and I couldn’t find an answer
Error:
Batch No. 0
images: torch.Size([185, 3, 200, 64]), Targets: torch.Size([185, 63]), lengths: torch.Size([185])
LOGIT SHAPE torch.Size([16, 185, 113])
LOGIT SHAPE torch.Size([16, 185, 113])
Logit Lengths : torch.Size([185]) Target : torch.Size([185])
__________________________________________________________________________
Batch No. 1
images: torch.Size([185, 3, 200, 64]), Targets: torch.Size([185, 66]), lengths: torch.Size([185])
LOGIT SHAPE torch.Size([16, 185, 113])
LOGIT SHAPE torch.Size([16, 185, 113])
Logit Lengths : torch.Size([185]) Target : torch.Size([185])
__________________________________________________________________________
Batch No. 2
images: torch.Size([185, 3, 200, 64]), Targets: torch.Size([185, 68]), lengths: torch.Size([185])
LOGIT SHAPE torch.Size([16, 185, 113])
LOGIT SHAPE torch.Size([16, 185, 113])
Logit Lengths : torch.Size([185]) Target : torch.Size([185])
__________________________________________________________________________
Batch No. 3
images: torch.Size([185, 3, 200, 64]), Targets: torch.Size([185, 65]), lengths: torch.Size([185])
LOGIT SHAPE torch.Size([16, 185, 113])
LOGIT SHAPE torch.Size([16, 185, 113])
Logit Lengths : torch.Size([185]) Target : torch.Size([185])
__________________________________________________________________________
Batch No. 4
images: torch.Size([185, 3, 200, 64]), Targets: torch.Size([185, 70]), lengths: torch.Size([185])
LOGIT SHAPE torch.Size([16, 185, 113])
LOGIT SHAPE torch.Size([16, 185, 113])
Logit Lengths : torch.Size([185]) Target : torch.Size([185])
__________________________________________________________________________
Batch No. 5
images: torch.Size([171, 3, 200, 64]), Targets: torch.Size([171, 60]), lengths: torch.Size([171])
LOGIT SHAPE torch.Size([16, 171, 113])
LOGIT SHAPE torch.Size([16, 171, 113])
Logit Lengths : torch.Size([185]) Target : torch.Size([171])
__________________________________________________________________________
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-18-14bff90b3d84> in <cell line: 5>()
28
29 # Calculate the CTC loss
---> 30 loss = ctc_loss(logits, targets, logit_lengths, target_lengths)
31 i += 1
32 optimizer.zero_grad()
3 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
1519
1520 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1528
1529 try:
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/loss.py in forward(self, log_probs, targets, input_lengths, target_lengths)
1768
1769 def forward(self, log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor) -> Tensor:
-> 1770 return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction,
1771 self.zero_infinity)
1772
/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py in ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity)
2654 blank=blank, reduction=reduction, zero_infinity=zero_infinity
2655 )
-> 2656 return torch.ctc_loss(
2657 log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction), zero_infinity
2658 )
RuntimeError: input_lengths must be of size batch_size
Train loop:
# Training loop
for epoch in range(num_epochs):
crnn_model.train()
total_loss = 0.0
i = 0
for images, targets, target_lengths in train_loader:
print("Batch No.",i)
images = images.to(device)
targets = targets.to(device)
print(f"images: {images.shape}, Targets: {targets.shape}, lengths: {target_lengths.shape} ")
logits = crnn_model(images) # Outputs should be [TimeStep, Batch, NumClass]
logit_lengths = torch.LongTensor([logits.size(0)] * batch_size[0])
print(f"LOGIT SHAPE {logits.shape}")
# logits = logits.transpose(0, 1)
# print(f"LOGIT SHAPE {logits.shape}")
print(f" Logit Lengths : {logit_lengths.shape} Target : {target_lengths.shape}")
print("__________________________________________________________________________")
# Calculate the CTC loss
loss = ctc_loss(logits, targets, logit_lengths, target_lengths)
i += 1
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(train_loader)
writer.add_scalar('Loss/Train', avg_loss, epoch)
print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}')
# Validation
if (epoch + 1) % 1 == 0:
crnn_model.eval()
val_loss = 0.0
with torch.no_grad():
for val_images, val_targets, val_target_lengths in val_loader:
val_images = val_images.to(device)
val_targets = val_targets.to(device)
val_logits = crnn_model(val_images)
val_logit_lengths = torch.LongTensor([val_logits.size(0)] * batch_size[1])
val_logits = torch.nn.functional.log_softmax(val_logits, dim=2)
val_loss += ctc_loss(val_logits, val_targets, val_logit_lengths, val_target_lengths).item()
_, predicted_labels = torch.max(val_logits, 2)
predicted_labels = ["".join([dataset.char_list[c] for c in row if c != 0]) for row in predicted_labels.cpu().numpy()]
for pred, target in zip(predicted_labels, val_targets.cpu().numpy()):
distance = levenshtein_distance(pred, "".join([dataset.char_list[c] for c in target if c != 0]))
writer.add_scalar('LevenshteinDistance/Validation', distance, epoch)
avg_val_loss = val_loss / len(val_loader)
writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
crnn_model.train()
print(f'Validation Loss: {avg_val_loss:.4f}')
torch.save(crnn_model.state_dict(), 'chk_pts/crnn_model.pth')
writer.close()
dataloader:
class DataLoader(object):
def __init__(self, ds, batch_size=(16, 16), validation_split=0.2,
shuffle=True, seed=42, device='cpu', blank_label=9999):
assert isinstance(ds, SharadaDataset)
assert isinstance(batch_size, tuple)
assert isinstance(validation_split, float)
assert isinstance(shuffle, bool)
assert isinstance(seed, int)
assert isinstance(device, str)
self.ds = ds
self.batch_size = batch_size
self.validation_split = validation_split
self.shuffle = shuffle
self.seed = seed
self.device = device
self.blank_label = blank_label
def __call__(self):
dataset_size = len(self.ds)
indices = list(range(dataset_size))
split = int(np.floor(self.validation_split * dataset_size))
if self.shuffle:
np.random.seed(self.seed)
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
# Dataloader
train_loader = DataLoader(self.ds, batch_size=self.batch_size[0],
sampler=train_sampler, collate_fn=self.collate_fn)
validation_loader = DataLoader(self.ds, batch_size=self.batch_size[1],
sampler=valid_sampler, collate_fn=self.collate_fn)
return train_loader, validation_loader
def collate_fn(self, batch):
images, labels = [b.get('image') for b in batch], [b.get('label') for b in batch]
images = torch.stack(images, 0)
lengths = [len(label) for label in labels]
max_label_len = max(lengths)
targets = []
for j, label in enumerate(labels):
temp = [self.ds.char_dict.get(letter) for letter in label]
temp.extend([self.blank_label] * (max_label_len - len(label)))
targets.append(torch.tensor(temp))
# targets.append([torch.tensor([self.ds.char_dict.get(letter) for letter in label]).long()])
targets = torch.stack(targets, 0)
lengths = torch.tensor(lengths)
if self.device == 'cpu':
dev = torch.device('cpu')
else:
dev = torch.device('cuda')
return images.to(dev), targets.to(dev), lengths.to(dev)