What PyTorch loss function should I use for per-element classification of a 1D sequence?

I have a network that produces predictions of shape [8, 32, 400]: 8 is the batch size, 32 is the length of my sequence, and 400 is the number of classes for each element in the sequence.

Basically, I want to classify each element of the sequence (like a simplified version of semantic segmentation). So my labels have shape [8, 32], where 8 is the batch size and 32 is the sequence length (the number of elements in the sequence).

What PyTorch loss function should I use here?

You could use nn.CrossEntropyLoss. It expects the model output as [batch_size, nb_classes, seq_len], while your targets are fine as [batch_size, seq_len], as long as they contain class indices. Use .permute on the model output to move the class dimension to position 1, e.g. output.permute(0, 2, 1).
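
Here is a minimal sketch of that suggestion, using the shapes from the question. The random `logits` tensor is only a stand-in for the real model output, and the variable names are made up for illustration. It also shows a flattened variant which, as far as I can tell, gives the same value with the default mean reduction.

```python
import torch
import torch.nn as nn

batch_size, seq_len, nb_classes = 8, 32, 400

# Stand-in for the model output (assumption): raw, unnormalized scores
# of shape [batch_size, seq_len, nb_classes].
logits = torch.randn(batch_size, seq_len, nb_classes, requires_grad=True)

# Targets hold one class index per sequence element: [batch_size, seq_len].
# randint returns int64 (long), which is what the loss expects for indices.
targets = torch.randint(0, nb_classes, (batch_size, seq_len))

criterion = nn.CrossEntropyLoss()

# Move the class dimension to position 1: [batch_size, nb_classes, seq_len].
loss = criterion(logits.permute(0, 2, 1), targets)
loss.backward()

# Alternative: merge batch and sequence dimensions instead of permuting.
loss_flat = criterion(logits.reshape(-1, nb_classes), targets.reshape(-1))

print(loss.item(), loss_flat.item())
```

Note that nn.CrossEntropyLoss applies log-softmax internally, so the model should output raw logits rather than softmax probabilities.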
