How to calculate cross entropy loss along the last axis?

I’m new to PyTorch and am trying to train a model with cross entropy loss.
The output of my model is of size [miniBatchSize, n, m]
and the label is of size [miniBatchSize, n],
where m is the number of categories and the label elements are integers between 0 and m - 1.
My current approach is:

labels = torch.eye(m)[labels]
outputs = pointnet(inputs)
outputs = F.log_softmax(outputs, -1)
loss = labels * outputs
loss = loss.sum()
loss.backward()

and I have absolutely no idea whether it’s going to work correctly. 🙁
Are there any built-in functions that can do the job for me? Thanks in advance.

That looks incorrect. The easiest way to fix it is to reshape your output and target tensors to (bs * n, c) and (bs * n) respectively, where bs is the batch size, n is the same as your n, and c is the number of classes:

labels = labels.reshape(-1)                                          # (bs * n,)
outputs = outputs.reshape(outputs.shape[0] * outputs.shape[1], -1)   # (bs * n, c)

Then you compute the normal cross entropy loss:

from torch.nn import CrossEntropyLoss

loss_fn = CrossEntropyLoss()
loss = loss_fn(outputs, labels)
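
Putting it together, here is a minimal runnable sketch of that approach with dummy tensors (bs, n and c are placeholder sizes here, and random logits stand in for pointnet(inputs)):

import torch
from torch.nn import CrossEntropyLoss

bs, n, c = 4, 10, 5                                   # batch size, points per sample, classes
outputs = torch.randn(bs, n, c, requires_grad=True)   # stand-in for pointnet(inputs)
labels = torch.randint(0, c, (bs, n))                 # integer class labels in [0, c - 1]

loss_fn = CrossEntropyLoss()                          # expects raw logits, not log-probabilities
loss = loss_fn(outputs.reshape(bs * n, c), labels.reshape(-1))
loss.backward()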

CrossEntropyLoss also accepts multi-dimensional input directly, but it expects the class dimension to come right after the batch dimension (i.e. (bs, c, n) rather than (bs, n, c)), so unless your dimensions are already in that order, the flattened version above is easier to use.
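
For completeness, a sketch of that direct multi-dimensional usage, assuming the original un-flattened tensors (outputs of shape (bs, n, c) and labels of shape (bs, n)); the class dimension just has to be moved into second position first:

loss = loss_fn(outputs.permute(0, 2, 1), labels)      # input (bs, c, n), target (bs, n)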


That’s exactly what I’m looking for, thanks! 😀