I’m trying to implement a simple transformer using CTC loss. I’m using a batch size of 1. My inputs are masked as follows:
mask = ~mask.any(axis=2)
mask = mask.unsqueeze(-1).expand(-1, -1, self.embedding_dim).bool()
positional_encodings = get_positional_encoding(video_frames.size(1), self.embedding_dim).to(video_frames.device)
video_frames = video_frames + positional_encodings
masked_video_frames = video_frames * mask.float()
transformer_output = self.transformer_encoder( # expects [batch, seq_len, emb_dim]
masked_video_frames,
)
letter_predictions = self.ctc_head(transformer_output)
return letter_predictions, mask
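For context, `get_positional_encoding` isn't shown above; this is a minimal sketch of what I mean by it, assuming the standard sinusoidal encoding (the function name and signature just mirror the call in my forward pass, and it assumes an even `embedding_dim`):

```python
import torch

def get_positional_encoding(seq_len: int, embedding_dim: int) -> torch.Tensor:
    """Standard sinusoidal positional encoding, shape [seq_len, embedding_dim]."""
    position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)   # [seq_len, 1]
    # Frequencies 10000^(-2i/d) for each even dimension index i
    div_term = torch.exp(
        torch.arange(0, embedding_dim, 2, dtype=torch.float)
        * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
    )
    pe = torch.zeros(seq_len, embedding_dim)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dims get sine
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dims get cosine
    return pe
```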
I then compute the CTC loss in the dataloader loop as follows:
video_frames, target, mask = data_batch
optimizer.zero_grad()
letter_predictions, frame_mask = model(video_frames.to(device), mask.to(device))
log_probs = nn.functional.log_softmax(letter_predictions, dim=2)
# Compute CTC loss
frame_mask = frame_mask.any(dim=2).unsqueeze(-1)
masked_log_probs = log_probs * frame_mask.float()
input_lengths = torch.sum(~frame_mask, dim=0)
target_concatenated = torch.cat(target, dim=0).to(device)
target_length = len(target_concatenated)
target_lengths = torch.tensor([target_length], dtype=torch.long).to(device)
print(f"Log probs shape: {masked_log_probs.shape}")
print(f"Input lengths shape: {input_lengths.shape}")
print(f"Target concatenated: {target_concatenated}")
print(f"Target concatenated shape: {target_concatenated.shape}")
print(f"Target lengths shape: {target_lengths.shape}")
loss = criterion(masked_log_probs, target_concatenated, input_lengths, target_lengths)
I got: "RuntimeError: target_lengths must be of size batch_size".
Now my shapes are as follows:
Log probs shape: torch.Size([1, 152, 58])
Input lengths shape: torch.Size([152, 1])
Target concatenated: tensor([21, 23, 17], device='cuda:0')
Target concatenated shape: torch.Size([3])
Target lengths shape: torch.Size([1])
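For reference, here is a minimal sketch of the shapes `torch.nn.CTCLoss` documents as its inputs; the T/N/C values just mirror the shapes printed above, and the blank index 0 is an assumption:

```python
import torch
import torch.nn as nn

# nn.CTCLoss expects:
#   log_probs:      [T, N, C]  (time, batch, classes), time-major
#   targets:        [sum(target_lengths)] when target sequences are concatenated
#   input_lengths:  [N]  (one length per batch element)
#   target_lengths: [N]
T, N, C = 152, 1, 58
log_probs = torch.randn(T, N, C).log_softmax(dim=2)
targets = torch.tensor([21, 23, 17])                   # one sequence of length 3
input_lengths = torch.full((N,), T, dtype=torch.long)  # shape [1]
target_lengths = torch.tensor([3], dtype=torch.long)   # shape [1]

criterion = nn.CTCLoss(blank=0)
loss = criterion(log_probs, targets, input_lengths, target_lengths)
```

Note that `log_probs` in this sketch is time-major; a batch-first tensor like my `[1, 152, 58]` one would need `.permute(1, 0, 2)` first, and both length tensors are 1-D of size N.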
I'm stuck on this; sorry if it's too beginner a question, but I can't figure out what is happening. My target_lengths shape is already 1, the same as my batch size.
Thanks in advance for any help.