Batch Size suddenly changes while training a CRNN with ctc loss

Can someone please point out why this is happening? I spent hours looking for similar errors and I couldn’t find an answer

Error:

Batch No. 0
images: torch.Size([185, 3, 200, 64]), Targets: torch.Size([185, 63]), lengths: torch.Size([185]) 
LOGIT SHAPE torch.Size([16, 185, 113])
LOGIT SHAPE torch.Size([16, 185, 113])
 Logit Lengths : torch.Size([185])  Target : torch.Size([185])
__________________________________________________________________________
Batch No. 1
images: torch.Size([185, 3, 200, 64]), Targets: torch.Size([185, 66]), lengths: torch.Size([185]) 
LOGIT SHAPE torch.Size([16, 185, 113])
LOGIT SHAPE torch.Size([16, 185, 113])
 Logit Lengths : torch.Size([185])  Target : torch.Size([185])
__________________________________________________________________________
Batch No. 2
images: torch.Size([185, 3, 200, 64]), Targets: torch.Size([185, 68]), lengths: torch.Size([185]) 
LOGIT SHAPE torch.Size([16, 185, 113])
LOGIT SHAPE torch.Size([16, 185, 113])
 Logit Lengths : torch.Size([185])  Target : torch.Size([185])
__________________________________________________________________________
Batch No. 3
images: torch.Size([185, 3, 200, 64]), Targets: torch.Size([185, 65]), lengths: torch.Size([185]) 
LOGIT SHAPE torch.Size([16, 185, 113])
LOGIT SHAPE torch.Size([16, 185, 113])
 Logit Lengths : torch.Size([185])  Target : torch.Size([185])
__________________________________________________________________________
Batch No. 4
images: torch.Size([185, 3, 200, 64]), Targets: torch.Size([185, 70]), lengths: torch.Size([185]) 
LOGIT SHAPE torch.Size([16, 185, 113])
LOGIT SHAPE torch.Size([16, 185, 113])
 Logit Lengths : torch.Size([185])  Target : torch.Size([185])
__________________________________________________________________________
Batch No. 5
images: torch.Size([171, 3, 200, 64]), Targets: torch.Size([171, 60]), lengths: torch.Size([171]) 
LOGIT SHAPE torch.Size([16, 171, 113])
LOGIT SHAPE torch.Size([16, 171, 113])
 Logit Lengths : torch.Size([185])  Target : torch.Size([171])
__________________________________________________________________________

---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-18-14bff90b3d84> in <cell line: 5>()
     28 
     29         # Calculate the CTC loss
---> 30         loss = ctc_loss(logits, targets, logit_lengths, target_lengths)
     31         i += 1
     32         optimizer.zero_grad()

3 frames

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/loss.py in forward(self, log_probs, targets, input_lengths, target_lengths)
   1768 
   1769     def forward(self, log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor) -> Tensor:
-> 1770         return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction,
   1771                           self.zero_infinity)
   1772 

/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py in ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity)
   2654             blank=blank, reduction=reduction, zero_infinity=zero_infinity
   2655         )
-> 2656     return torch.ctc_loss(
   2657         log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction), zero_infinity
   2658     )

RuntimeError: input_lengths must be of size batch_size

Train loop:


# Training loop
for epoch in range(num_epochs):
    crnn_model.train()
    total_loss = 0.0
    i = 0
    for images, targets, target_lengths in train_loader:
        print("Batch No.",i)
        images = images.to(device)
        targets = targets.to(device)

        print(f"images: {images.shape}, Targets: {targets.shape}, lengths: {target_lengths.shape} ")

        logits = crnn_model(images) # Outputs should be [TimeStep, Batch, NumClass]
        logit_lengths = torch.LongTensor([logits.size(0)] * batch_size[0])

        print(f"LOGIT SHAPE {logits.shape}")
        # logits = logits.transpose(0, 1)
        # print(f"LOGIT SHAPE {logits.shape}")
        print(f" Logit Lengths : {logit_lengths.shape}  Target : {target_lengths.shape}")
        print("__________________________________________________________________________")

        # Calculate the CTC loss
        loss = ctc_loss(logits, targets, logit_lengths, target_lengths)
        i += 1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    writer.add_scalar('Loss/Train', avg_loss, epoch)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}')

    # Validation
    if (epoch + 1) % 1 == 0:
        crnn_model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for val_images, val_targets, val_target_lengths in val_loader:
                val_images = val_images.to(device)
                val_targets = val_targets.to(device)

                val_logits = crnn_model(val_images)
                val_logit_lengths = torch.LongTensor([val_logits.size(0)] * batch_size[1])

                val_logits = torch.nn.functional.log_softmax(val_logits, dim=2)



                val_loss += ctc_loss(val_logits, val_targets, val_logit_lengths, val_target_lengths).item()

                _, predicted_labels = torch.max(val_logits, 2)
                predicted_labels = ["".join([dataset.char_list[c] for c in row if c != 0]) for row in predicted_labels.cpu().numpy()]

                for pred, target in zip(predicted_labels, val_targets.cpu().numpy()):
                    distance = levenshtein_distance(pred, "".join([dataset.char_list[c] for c in target if c != 0]))

                    writer.add_scalar('LevenshteinDistance/Validation', distance, epoch)

        avg_val_loss = val_loss / len(val_loader)
        writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
        crnn_model.train()

        print(f'Validation Loss: {avg_val_loss:.4f}')

torch.save(crnn_model.state_dict(), 'chk_pts/crnn_model.pth')
writer.close()

dataloader:


class DataLoader(object):

    def __init__(self, ds, batch_size=(16, 16), validation_split=0.2,
                 shuffle=True, seed=42, device='cpu', blank_label=9999):
        assert isinstance(ds, SharadaDataset)
        assert isinstance(batch_size, tuple)
        assert isinstance(validation_split, float)
        assert isinstance(shuffle, bool)
        assert isinstance(seed, int)
        assert isinstance(device, str)

        self.ds = ds
        self.batch_size = batch_size
        self.validation_split = validation_split
        self.shuffle = shuffle
        self.seed = seed
        self.device = device
        self.blank_label = blank_label

    def  __call__(self):

        dataset_size = len(self.ds)
        indices = list(range(dataset_size))
        split = int(np.floor(self.validation_split * dataset_size))

        if self.shuffle:
            np.random.seed(self.seed)
            np.random.shuffle(indices)
        train_indices, val_indices = indices[split:], indices[:split]

        # Creating PT data samplers and loaders:
        train_sampler = SubsetRandomSampler(train_indices)
        valid_sampler = SubsetRandomSampler(val_indices)

        # Dataloader
        train_loader = DataLoader(self.ds, batch_size=self.batch_size[0],
                                  sampler=train_sampler, collate_fn=self.collate_fn)
        validation_loader = DataLoader(self.ds, batch_size=self.batch_size[1],
                                       sampler=valid_sampler, collate_fn=self.collate_fn)

        return train_loader, validation_loader



    def collate_fn(self, batch):
        images, labels = [b.get('image') for b in batch], [b.get('label') for b in batch]
        images = torch.stack(images, 0)
        lengths = [len(label) for label in labels]
        max_label_len = max(lengths)
        targets = []
        for j, label in enumerate(labels):
          temp = [self.ds.char_dict.get(letter) for letter in label]
          temp.extend([self.blank_label] * (max_label_len - len(label)))
          targets.append(torch.tensor(temp))
          # targets.append([torch.tensor([self.ds.char_dict.get(letter) for letter in label]).long()])

        targets = torch.stack(targets, 0)
        lengths = torch.tensor(lengths)

        if self.device == 'cpu':
            dev = torch.device('cpu')
        else:
            dev = torch.device('cuda')

        return images.to(dev), targets.to(dev), lengths.to(dev)