Runtime error: target_lengths must be of size batch_size with CTC loss using a batch size of 1

I’m trying to implement a simple transformer using CTC loss. I’m using a batch size of 1. My inputs are masked as follows:

        mask = ~mask.any(axis=2)
        mask = mask.unsqueeze(-1).expand(-1, -1, self.embedding_dim).bool()
        
        positional_encodings = get_positional_encoding(video_frames.size(1), self.embedding_dim).to(video_frames.device)
        video_frames = video_frames + positional_encodings
            
        masked_video_frames = video_frames * mask.float()
               
        transformer_output = self.transformer_encoder( # expects [batch, seq_len, emb_dim]
            masked_video_frames,
        ) 
        letter_predictions = self.ctc_head(transformer_output)
        return letter_predictions, mask

When computing the CTC loss in the dataloader loop as follows:

        video_frames, target, mask = data_batch 
        optimizer.zero_grad()    

        letter_predictions, frame_mask = model(video_frames.to(device), mask.to(device))
        log_probs = nn.functional.log_softmax(letter_predictions, dim=2) 
        
        # Compute CTC loss
        frame_mask = frame_mask.any(dim=2).unsqueeze(-1)  
        masked_log_probs = log_probs * frame_mask.float()  
        
        input_lengths = torch.sum(~frame_mask, dim=0)  
        target_concatenated = torch.cat(target, dim=0).to(device)         
        target_length = len(target_concatenated)  
        target_lengths = torch.tensor([target_length], dtype=torch.long).to(device)
        
        print(f"Log probs shape: {masked_log_probs.shape}")
        print(f"Input lengths shape: {input_lengths.shape}")
        print(f"Target concatenated: {target_concatenated}")
        print(f"Target concatenated shape: {target_concatenated.shape}")
        print(f"Target lengths shape: {target_lengths.shape}")
       
        loss = criterion(masked_log_probs, target_concatenated, input_lengths, target_lengths)

I got: "RuntimeError: target_lengths must be of size batch_size".

Now my shapes are as follows:

Log probs shape: torch.Size([1, 152, 58])
Input lengths shape: torch.Size([152, 1])
Target concatenated: tensor([21, 23, 17], device='cuda:0')
Target concatenated shape: torch.Size([3])
Target lengths shape: torch.Size([1])

I'm stuck on this, sorry if this is a beginner question, but I can't figure out what is happening. My `target_lengths` shape is already 1, the same as my batch size.

Thanks in advance for any help.

Solved it: I was passing `log_probs` as a tensor of shape `[batch_size, sequence_length, features]`, while `nn.CTCLoss` expects the batch size on the second dimension, i.e. `[sequence_length, batch_size, num_classes]`. Since my batch size was 1, the loss was interpreting the sequence length (152) as the batch size and complaining that `target_lengths` (size 1) didn't match it.
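For anyone hitting the same error, here is a minimal sketch of the fix using the shapes from my printout (152 frames, 58 classes, 3 target labels; `blank=0` is an assumption, adjust to your vocabulary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len, num_classes = 1, 152, 58  # shapes from the question

# The model emits [batch, seq_len, classes] ...
letter_predictions = torch.randn(batch_size, seq_len, num_classes)
log_probs = nn.functional.log_softmax(letter_predictions, dim=2)

# ... but nn.CTCLoss wants [seq_len, batch, classes], so permute first.
log_probs = log_probs.permute(1, 0, 2)  # -> [152, 1, 58]

# One entry per batch element: the number of valid input frames.
input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long)

# Targets concatenated into a 1-D tensor, with one length per batch element.
targets = torch.tensor([21, 23, 17], dtype=torch.long)
target_lengths = torch.tensor([3], dtype=torch.long)

criterion = nn.CTCLoss(blank=0)
loss = criterion(log_probs, targets, input_lengths, target_lengths)
print(log_probs.shape, loss.item())
```

Note that `input_lengths` must also be one value per batch element (shape `[batch_size]`), counting valid frames per sample, not the `[152, 1]` tensor my masking code produced.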