I have a network that produces predictions of shape [8, 32, 400]. The first number, 8, is the batch size; the second, 32, is the length of my sequence; and the third, 400, is the total number of classes for each element in the sequence.

Basically, I want to perform per-element classification for each sequence (like a simplified version of semantic segmentation). So I define my labels to have shape [8, 32], where 8 is the batch size and 32 is the sequence length (the total number of elements in each sequence).

What PyTorch loss function should I use here?
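
To make the setup concrete, here is a minimal sketch of the shapes involved. The random tensors are stand-ins for my real network output and labels, and all variable names are made up for illustration. It assumes the predictions are raw (unnormalized) scores and the labels are integer class indices. One candidate seems to be `nn.CrossEntropyLoss`, which expects the class dimension in position 1 for multi-dimensional inputs, so the sketch either moves that dimension or flattens the batch and sequence dimensions together:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len, num_classes = 8, 32, 400

# Stand-ins for the real network output and labels (random, for illustration only).
logits = torch.randn(batch_size, seq_len, num_classes)         # [8, 32, 400]
labels = torch.randint(0, num_classes, (batch_size, seq_len))  # [8, 32], int64 class indices

criterion = nn.CrossEntropyLoss()

# Option A: move the class dimension to position 1, giving [8, 400, 32].
loss_a = criterion(logits.permute(0, 2, 1), labels)

# Option B: flatten batch and sequence, giving [256, 400] against [256].
loss_b = criterion(logits.reshape(-1, num_classes), labels.reshape(-1))
```

With the default mean reduction, both options average over all 8 × 32 elements, so they should give the same scalar loss up to floating-point error.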