Training a single embedding using masking

True. The solution is the same as in nn.CrossEntropyLoss() for text with multiple dimensions.

However, it is misleading that nn.CrossEntropyLoss() expects the number of classes as the second dimension.
It would make more sense for the loss to accept the output as (batch_size, dimension1, nb_classes), which would then be
reduced to (batch_size, dimension1).
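
For reference, PyTorch's actual convention puts the class dimension second: logits of shape (batch_size, nb_classes, dimension1) against targets of shape (batch_size, dimension1). Here is a minimal sketch with made-up toy shapes, showing that with reduction="none" the per-element loss does come out as (batch_size, dimension1):

```python
import torch
import torch.nn as nn

batch_size, seq_len, nb_classes = 4, 10, 32  # toy shapes, not from the post

# PyTorch's layout: class scores go in dim 1, not in the last dim.
logits = torch.randn(batch_size, nb_classes, seq_len)
targets = torch.randint(0, nb_classes, (batch_size, seq_len))

loss_fn = nn.CrossEntropyLoss(reduction="none")
per_token_loss = loss_fn(logits, targets)
print(per_token_loss.shape)  # torch.Size([4, 10]) == (batch_size, dimension1)
```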

It is misleading because, from an application perspective, the model naturally produces an output of shape (batch_size, dimension1, nb_classes), while nn.CrossEntropyLoss() expects (batch_size, nb_classes, dimension1). You have to add another functional operation to reconcile the two layouts.

For example, check this BERT training code: they added an "output.transpose(1, 2)" operation before computing the loss.
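
Below is a minimal sketch of that pattern for masked-token training. The model, the shapes, and the 15% masking rate are placeholders (the actual BERT code referenced above is not quoted here); the point is the transpose from the natural (batch_size, seq_len, vocab_size) layout into the (batch_size, vocab_size, seq_len) layout that nn.CrossEntropyLoss() requires, plus ignore_index so only the masked positions contribute to the loss:

```python
import torch
import torch.nn as nn

# Placeholder stand-in for a real masked language model.
vocab_size, batch_size, seq_len = 1000, 8, 128
model = nn.Sequential(nn.Embedding(vocab_size, 64), nn.Linear(64, vocab_size))

input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

# Labels: masked positions keep the true token id; everything else is -100,
# which CrossEntropyLoss skips via ignore_index.
labels = torch.full((batch_size, seq_len), -100, dtype=torch.long)
mask = torch.rand(batch_size, seq_len) < 0.15  # assumed 15% masking rate
labels[mask] = input_ids[mask]

criterion = nn.CrossEntropyLoss(ignore_index=-100)

output = model(input_ids)                      # (batch_size, seq_len, vocab_size)
# The loss expects (batch_size, vocab_size, seq_len), hence the transpose.
loss = criterion(output.transpose(1, 2), labels)
loss.backward()
```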