CrossEntropyLoss on Sequence of Class Probabilities

I'm wondering what the best practice is here. My model outputs a tensor of shape [batch_size, sequence_length, class_probabilities], and I would like to apply CrossEntropyLoss over the class dimension and sum that loss over the sequence length before backprop.

So my question is: should I loop over the sequence length and sum the losses (as in the sketch below), or is there a better way?
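
For concreteness, the loop-based version I have in mind looks roughly like this (just a sketch; names and shapes are placeholders):

import torch
import torch.nn as nn

batch_size, sequence_length, num_classes = 2, 3, 4
logits = torch.randn(batch_size, sequence_length, num_classes, requires_grad=True)
targets = torch.randint(0, num_classes, (batch_size, sequence_length))

criterion = nn.CrossEntropyLoss()  # default reduction="mean" averages over the batch

# loop over the sequence dimension and accumulate the per-step losses
loss = 0.0
for t in range(sequence_length):
    loss = loss + criterion(logits[:, t, :], targets[:, t])  # [batch_size, num_classes] vs. [batch_size]

loss.backward()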

I'm unsure if I understand the use case correctly, but would this work? Note that nn.CrossEntropyLoss expects raw logits rather than probabilities (it applies log_softmax internally), so pass the model output before any softmax.

import torch
import torch.nn as nn

batch_size = 2
sequence_length = 3
class_probabilities = 4  # number of classes

logits = torch.randn([batch_size, sequence_length, class_probabilities], requires_grad=True)
targets = torch.randint(0, class_probabilities, (batch_size, sequence_length))

criterion = nn.CrossEntropyLoss(reduction="none")

logits = logits.permute(0, 2, 1) # permute into [batch_size, nb_classes, seq_len]
loss = criterion(logits, targets)
print(loss.shape)
# torch.Size([2, 3]) # [batch_size, sequence_length]
loss = loss.sum(1) # sum over the sequence_length dim

# reduce the loss further e.g. via mean or call backward with the gradient argument
loss.mean().backward()
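
As a follow-up: an equivalent alternative is to flatten the batch and sequence dimensions instead of permuting. A sketch, reusing criterion and the sizes defined above:

# flatten instead of permuting: [batch * seq, nb_classes] vs. [batch * seq]
logits = torch.randn(batch_size, sequence_length, class_probabilities, requires_grad=True)
targets = torch.randint(0, class_probabilities, (batch_size, sequence_length))

loss = criterion(logits.reshape(-1, class_probabilities), targets.reshape(-1))
loss = loss.view(batch_size, sequence_length).sum(1)  # per-sample loss summed over the sequence
loss.mean().backward()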

That actually worked. Thanks @ptrblck.
