Hi folks,
I have a clarifying (technical) question about using cross-entropy with seq2seq models.
B = batch_size
S = sequence_length (subscripts, i=input, o=output)
E = embedding_size
V = vocabulary_size
C = num_classes/labels
X = [B, S_i, V] # our input data
Y = [B, S_o, C] # our targets, consider them one-hot encoded
X_enc = encoder(X) # X_enc = [B, S_i, E]
X_dec = decoder(X_enc, Y) # X_dec = [B, S_o, E]
loss = nn.CrossEntropyLoss()(X_dec, Y) # this shouldn't work due to shape mismatch
How should cross entropy work in this case? There's a shape mismatch between X_dec and Y.
Plus, to my understanding, cross entropy accepts inputs of shape [B, C] and targets of shape [B, 1]. What am I missing here?
Is there another way this is computed?
One way I was thinking of approaching it is the following, but I'm not sure if it's correct:
That’s not the case as described in the docs.
If you are using class indices, the target should have the shape [B], given your model output shape [B, C]. If the target contains probabilities, it should match the model output shape, so [B, C].
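To make the two target formats concrete, here is a minimal sketch. The sizes are arbitrary, and it assumes a PyTorch version where nn.CrossEntropyLoss accepts probability targets (1.10 or later, as far as I know):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C = 4, 5  # batch size and number of classes (arbitrary values)

logits = torch.randn(B, C)  # model output: [B, C]
criterion = nn.CrossEntropyLoss()

# Target as class indices: shape [B], dtype long
target_idx = torch.randint(0, C, (B,))
loss_idx = criterion(logits, target_idx)

# Target as probabilities: same shape as the output, [B, C], floating point
target_prob = F.one_hot(target_idx, num_classes=C).float()
loss_prob = criterion(logits, target_prob)

# With one-hot probabilities, both target formats give the same loss
print(torch.allclose(loss_idx, loss_prob))
```

Note that a target of shape [B, 1] matches neither format; it would need to be squeezed to [B] first.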
Yes, correct, my mistake. I've corrected it in the post.
But do you have any idea how cross entropy is computed when both outputs = [B, S_o, E] and targets_one_hot = [B, S_o, C] are 3-dim? This seems to be quite a common loss in seq2seq tasks, but since this case is not covered in the docs (unless I'm missing something), is there another reference or MWE?
The docs also cover this use case and will treat the output as [batch_size, nb_classes, additional_dimension]. Since the target is also 3-dimensional, it must have the same shape, so C==E in your case.
Here is an example:
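A minimal sketch along these lines, with arbitrary sizes and random tensors standing in for the decoder output (the variable names are only illustrative). The class dimension is moved to dim=1 before calling the loss:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, S, C = 2, 7, 5  # batch size, sequence length, number of classes (arbitrary)

# Stand-in for the decoder output, laid out as [B, S, C]
output = torch.randn(B, S, C)
# Class-index targets, one label per sequence position: [B, S]
target_idx = torch.randint(0, C, (B, S))

criterion = nn.CrossEntropyLoss()

# nn.CrossEntropyLoss expects the class dimension at dim=1: [B, C, S]
logits = output.permute(0, 2, 1)

# Class-index target keeps the shape [B, S]
loss_idx = criterion(logits, target_idx)

# One-hot (probability) target must match the logits shape: [B, C, S]
target_one_hot = F.one_hot(target_idx, num_classes=C).float().permute(0, 2, 1)
loss_prob = criterion(logits, target_one_hot)

print(loss_idx.item(), loss_prob.item())
```

Both calls are accepted because the target either drops the class dimension ([B, S]) or matches the output exactly ([B, C, S]).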
As I was searching around, I saw that in many cases another approach used was to flatten the 3-dim tensors from [B, S, E] to [B*S, E] and then compute the loss, which I presume is similar to what you described.
As I understood from your MWE, there are 2 key points here.
In order for cross-entropy to work with 3-dim tensors, we should have nb_classes as dim=1 and let cross-entropy compute the loss over nb_classes, e.g., both output and target should have shape [batch_size, nb_classes, seq_len].
Otherwise, if we have to flatten dimensions, we should flatten across batch_size*seq_len, so our tensors should start from shape [batch_size, seq_len, nb_classes], i.e., my case.
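A quick way to check that the two options agree is the sketch below. It assumes class-index targets, the default mean reduction, and no class weights or ignore_index; sizes are arbitrary:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, S, C = 3, 6, 10  # batch size, sequence length, number of classes (arbitrary)

output = torch.randn(B, S, C)          # [batch_size, seq_len, nb_classes]
target = torch.randint(0, C, (B, S))   # class indices: [batch_size, seq_len]
criterion = nn.CrossEntropyLoss()

# Option 1: move nb_classes to dim=1 -> [batch_size, nb_classes, seq_len]
loss_permute = criterion(output.permute(0, 2, 1), target)

# Option 2: flatten batch and sequence dims -> [B*S, C] and [B*S]
loss_flat = criterion(output.reshape(B * S, C), target.reshape(B * S))

# Both average the per-token loss over all B*S positions
print(torch.allclose(loss_permute, loss_flat))
```

Under these assumptions the two losses match, so the choice is mostly a matter of which layout the rest of the model already uses.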