TL;DR: I want to compute the cross-entropy loss over a batch in which each element can use a different masking strategy.

I have an interesting question. Suppose my model outputs a tensor of shape `[B, N, C]`, where `B` is the batch size, `N` is the sequence length, and `C` is the vocabulary size.

To compute the cross-entropy loss, I want to select every 2nd word, or every 3rd word, and so on, depending on the original sequence length of each element. However, I am not sure whether this operation can be parallelized across the batch.
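One possible way to do this without a Python loop is to build a boolean mask of shape `[B, N]` by broadcasting and apply it to the per-token loss. The sketch below assumes PyTorch and assumes the per-element step has already been computed. The names `strides` (keep every k-th word), `lengths` (original sequence lengths), and `strided_masked_cross_entropy` are hypothetical, and the rule that maps a sequence length to its step is not specified here, so it is left to the caller.

```python
import torch
import torch.nn.functional as F


def strided_masked_cross_entropy(logits, targets, strides, lengths):
    """Cross entropy over every k-th position, with a different k per batch element.

    logits:  [B, N, C] float tensor
    targets: [B, N] long tensor of class indices
    strides: [B] long tensor, k for each element (hypothetical input)
    lengths: [B] long tensor, unpadded length of each element (hypothetical input)
    """
    B, N, C = logits.shape
    positions = torch.arange(N, device=logits.device).unsqueeze(0)  # [1, N]

    # Broadcast [1, N] against [B, 1] to get a [B, N] mask:
    # keep the 1-based positions k, 2k, 3k, ... that lie inside the original length.
    on_stride = (positions + 1) % strides.unsqueeze(1) == 0
    in_length = positions < lengths.unsqueeze(1)
    keep = on_stride & in_length

    # Per-token loss with no reduction, reshaped back to [B, N].
    per_token = F.cross_entropy(
        logits.reshape(B * N, C), targets.reshape(B * N), reduction="none"
    ).reshape(B, N)

    # Average over kept positions only; clamp avoids division by zero.
    keep_f = keep.to(per_token.dtype)
    loss = (per_token * keep_f).sum() / keep_f.sum().clamp(min=1)
    return loss, keep
```

Since the mask comes from a single broadcasted comparison, all batch elements are handled in parallel. If a different selection rule is needed, only the line that builds `on_stride` should have to change.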