Why do we flatten the output in Language Model training?

I am following this tutorial to create a language model with transformers: Language Modeling with nn.Transformer and torchtext — PyTorch Tutorials 2.2.0+cu121 documentation

I have some questions about the way we process the data. Correct me if I am wrong, but we first do the following:

  • read the text and tokenize it into one flat tensor of token IDs
  • then split that flat tensor into batches
  • and slice it into chunks so that every training example has the same sequence length

So far, so good (I have sketched these steps in code below).
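
If I understand the tutorial correctly, the preparation looks roughly like this. The batchify / get_batch names come from the tutorial, and the bptt default of 30 is just there to match the shapes from my run, so treat this as a simplified sketch rather than the exact code:

import torch

def batchify(data: torch.Tensor, bsz: int) -> torch.Tensor:
    # Turn a flat 1-D tensor of token IDs into [seq_len, batch_size] columns.
    seq_len = data.size(0) // bsz
    data = data[:seq_len * bsz]                  # drop tokens that don't fit evenly
    return data.view(bsz, seq_len).t().contiguous()

def get_batch(source: torch.Tensor, i: int, bptt: int = 30):
    # Slice one fixed-length chunk and its next-token targets.
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]                         # [seq_len, batch_size]
    target = source[i + 1:i + 1 + seq_len].reshape(-1)   # [seq_len * batch_size]
    return data, target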

Let’s assume our vocab has 4817 elements. In each batch, data has size torch.Size([30, 20]), i.e. a sequence length of 30 and a batch size of 20. But then we flatten the targets in the batch into a tensor of size torch.Size([600]).

Then we compute the cross-entropy loss:

import torch.nn as nn

criterion = nn.CrossEntropyLoss()       # default reduction='mean' over all positions
loss = criterion(output_flat, targets)

After reshaping, the model output, output_flat, has size torch.Size([600, 4817]).
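
To spell out the shapes as I understand them, here is a self-contained sketch with made-up tensors (the ntokens / seq_len / batch_size values just mirror my run, and the random tensors stand in for real model output and targets):

import torch
import torch.nn as nn

ntokens, seq_len, batch_size = 4817, 30, 20

# Stand-in for the transformer output: one logit vector per position.
output = torch.randn(seq_len, batch_size, ntokens)          # [30, 20, 4817]
targets = torch.randint(ntokens, (seq_len * batch_size,))   # [600] flattened next tokens

output_flat = output.view(-1, ntokens)                      # [600, 4817]

criterion = nn.CrossEntropyLoss()
loss = criterion(output_flat, targets)                      # one scalar over all 600 positions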

My confusion is why we train the model like this. Why do we flatten all the targets in the batch and compute the loss over the entire batch as one flattened-out object? Should we not compute the loss separately for each instance (each sequence) in the batch?
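
To make the question concrete, this is the kind of per-instance computation I have in mind instead (a hypothetical sketch continuing the shapes above, not something from the tutorial):

# Hypothetical alternative I am asking about: one loss per sequence, then averaged.
raw_targets = targets.view(seq_len, batch_size)   # undo the flattening: [30, 20]

per_instance_losses = []
for b in range(batch_size):
    logits_b = output[:, b, :]        # [30, 4817] logits for sequence b
    targets_b = raw_targets[:, b]     # [30] next-token IDs for sequence b
    per_instance_losses.append(criterion(logits_b, targets_b))

loss_per_instance = torch.stack(per_instance_losses).mean()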