How is cross-entropy used in seq2seq models?

Hi folks,
I have a bit of a clarifying (technical) question about using cross-entropy with seq2seq models.

B = batch_size
S = sequence_length (subscripts, i=input, o=output)
E = embedding_size
V = vocabulary_size
C = num_classes/labels

X = [B, S_i, V]  # our input data

Y = [B, S_o, C] # our targets, consider them one-hot encoded

X_enc = encoder(X) # X_enc = [B, S_i, E]

X_dec = decoder(X_enc, Y) # X_dec = [B, S_o, E]

loss = nn.CrossEntropyLoss()(X_dec, Y) # this shouldn't work due to shape mismatch

How should cross-entropy work in this case, given the shape mismatch between X_dec and Y?

Plus, to my understanding, cross-entropy accepts inputs of shape [B, C] and targets of shape [B, 1]. What am I missing here?

Or is this computed in some other way?

One way I was thinking of going about it would be the following, but I'm not sure if it's correct:

torch.mean(torch.sum((-Y * F.log_softmax(X_dec, dim=-1)), dim=-1))

That's not the case. As described in the docs: if you are using class indices, the target should have the shape [B] given your model output shape; if the target contains probabilities, it should match the model output shape, so [B, C].
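
For example, in the plain 2-dim case the two target variants look like this (a minimal sketch with made-up shapes; probability targets need a reasonably recent PyTorch release):

import torch
import torch.nn as nn

B, C = 4, 3
logits = torch.randn(B, C, requires_grad=True)
criterion = nn.CrossEntropyLoss()

# class-index targets: shape [B], dtype long
target_idx = torch.randint(0, C, (B,))
loss_idx = criterion(logits, target_idx)

# probability (here one-hot) targets: shape [B, C], dtype float
target_prob = torch.nn.functional.one_hot(target_idx, num_classes=C).float()
loss_prob = criterion(logits, target_prob)

print(loss_idx, loss_prob)  # identical values, since the probabilities are one-hot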

Yes correct, my mistake, I’ve corrected it in the post.

But do you have any idea how cross-entropy is computed when both outputs = [B, S_o, E] and targets_one_hot = [B, S_o, C] are 3-dim? This seems like quite a common loss in seq2seq tasks, but since this case is not covered in the docs (unless I'm missing something), is there another reference or MWE?

The docs also cover this use case and will treat the output as [batch_size, nb_classes, additional_dimension]. Since the target is also 3-dimensional, it must have the same shape, so C == E in your case.
Here is an example:

import torch
import torch.nn as nn

batch_size = 2
nb_classes = 3
seq_len = 4
output = torch.randn(batch_size, nb_classes, seq_len, requires_grad=True)  # logits: [batch_size, nb_classes, seq_len]
target = torch.randint(0, nb_classes, (batch_size, seq_len))  # class indices: [batch_size, seq_len]
target = torch.nn.functional.one_hot(target, num_classes=nb_classes).permute(0, 2, 1).float()  # one-hot, class dim moved to dim=1
print(target.shape)
# torch.Size([2, 3, 4])

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)

The criterion will compute the cross-entropy loss for each step in seq_len in the same way it does for each sample in the batch dimension.
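
For completeness, the same example also works with class-index targets of shape [batch_size, seq_len], which skips the one-hot encoding entirely (a sketch of that alternative, not something you need if your targets are already one-hot):

import torch
import torch.nn as nn

batch_size = 2
nb_classes = 3
seq_len = 4
output = torch.randn(batch_size, nb_classes, seq_len, requires_grad=True)
# class indices instead of one-hot probabilities: [batch_size, seq_len]
target = torch.randint(0, nb_classes, (batch_size, seq_len))

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)

This index-based form is also the one that supports ignore_index, which is handy for masking out padded positions in seq2seq batches.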

Awesome, thanks a lot for clarifying this!

As I was searching around, I saw that in many cases another approach is to flatten the 3-dim tensors from [B, S, E] to [B*S, E] and then compute the loss, which I presume is similar to what you described.

Yes, but be careful with the dimensions.
Using my example, you would need to permute both the output and the target before flattening:

import torch
import torch.nn as nn

batch_size = 2
nb_classes = 3
seq_len = 4
output = torch.randn(batch_size, nb_classes, seq_len, requires_grad=True)
target = torch.randint(0, nb_classes, (batch_size, seq_len))
target = torch.nn.functional.one_hot(target, num_classes=nb_classes).permute(0, 2, 1).float()
print(target.shape)
# torch.Size([2, 3, 4])

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)

# with flattened tensors treating the sequence dimension as separate samples in the batch dimension
print(output.shape)
# torch.Size([2, 3, 4]) = [batch_size, nb_classes, seq_len]
output = output.permute(0, 2, 1).contiguous().view(-1, nb_classes)
print(output.shape)
# torch.Size([8, 3]) = [batch_size * seq_len, nb_classes]

print(target.shape)
# torch.Size([2, 3, 4]) = [batch_size, nb_classes, seq_len]
target = target.permute(0, 2, 1).contiguous().view(-1, nb_classes)
print(target.shape)
# torch.Size([8, 3]) = [batch_size * seq_len, nb_classes]

loss2 = criterion(output, target)

print(loss - loss2)
# tensor(0., grad_fn=<SubBackward0>)
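
And to connect this back to the manual formula from your first post: with the [batch_size, seq_len, nb_classes] layout (i.e. C == E) it gives the same value as the criterion. A self-contained sketch:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, seq_len, nb_classes = 2, 4, 3
# logits and one-hot targets in the [batch_size, seq_len, nb_classes] layout
x_dec = torch.randn(batch_size, seq_len, nb_classes)
y = F.one_hot(torch.randint(0, nb_classes, (batch_size, seq_len)), num_classes=nb_classes).float()

manual = torch.mean(torch.sum(-y * F.log_softmax(x_dec, dim=-1), dim=-1))

# nn.CrossEntropyLoss expects the class dimension at dim=1, so permute first
criterion = nn.CrossEntropyLoss()
builtin = criterion(x_dec.permute(0, 2, 1), y.permute(0, 2, 1))

print(torch.allclose(manual, builtin))
# True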

Thanks a lot.

As I understand it from your MWE, there are 2 key points here.

  1. For cross-entropy to work directly with 3-dim tensors, we should put nb_classes at dim=1 and let cross-entropy compute the loss over the classes, i.e., both output and target have the shape [batch_size, nb_classes, seq_len].

  2. Alternatively, if we flatten dimensions, we should flatten across batch_size * seq_len, so our tensors should have the shape [batch_size, seq_len, nb_classes] before flattening, i.e., my case (a quick sketch below).
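
To double-check point 2 against my own layout, here is how I would write it (just a sketch, assuming one-hot targets as in my original post; no permute is needed since nb_classes is already the last dimension):

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, seq_len, nb_classes = 2, 4, 3
# output and one-hot target already in the [batch_size, seq_len, nb_classes] layout
output = torch.randn(batch_size, seq_len, nb_classes, requires_grad=True)
target = F.one_hot(torch.randint(0, nb_classes, (batch_size, seq_len)), num_classes=nb_classes).float()

criterion = nn.CrossEntropyLoss()
# merge the batch and sequence dimensions into one sample dimension
loss = criterion(output.reshape(-1, nb_classes), target.reshape(-1, nb_classes))
print(loss)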