Cross Entropy Loss get predicted class

Hi all,

I am using cross entropy loss in my multiclass text classification problem, but I am confused. My targets are in the [0, c-1] format. How can I obtain the predicted class? An example would be helpful. Also, since cross entropy loss uses softmax, why don’t I get probabilities that sum to 1 as output?

nn.CrossEntropyLoss expects logits, as F.log_softmax and nn.NLLLoss are used internally.
If you want to get the predicted class, you could simply use torch.argmax:

output = model(input)
pred = torch.argmax(output, dim=1)

I assume dim1 represents the classes. If not, you should change the dim argument accordingly.
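Concretely, here is a minimal sketch of that equivalence (random tensors stand in for a real model’s outputs and targets):

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, num_classes = 3, 5
logits = torch.randn(batch_size, num_classes)          # raw model outputs
target = torch.randint(0, num_classes, (batch_size,))  # class indices in [0, c-1]

# nn.CrossEntropyLoss on raw logits ...
loss_ce = nn.CrossEntropyLoss()(logits, target)

# ... matches F.log_softmax followed by nn.NLLLoss
loss_nll = F.nll_loss(F.log_softmax(logits, dim=1), target)

print(torch.allclose(loss_ce, loss_nll))  # True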


Can you explain why this implementation for multiclass text classification doesn’t use sigmoid, given that it expects logits?

sigmoid would map each output independently to a probability in the range [0, 1].
Logits, on the other hand, are unbounded ((-inf, inf)), so you should not apply any activation function to your model outputs; nn.CrossEntropyLoss applies log_softmax internally.
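A quick sketch with made-up values to show the difference:

import torch
import torch.nn.functional as F

logits = torch.tensor([[0.5, -1.0, 2.0]])

print(torch.sigmoid(logits).sum(dim=1))     # ~1.77: per-element squashing, not a distribution
print(F.softmax(logits, dim=1).sum(dim=1))  # 1.0: a proper distribution over the classes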

So if I have 5 output classes and 3 test instances and, for example, I get the output below:

tensor([[ 0.4657, -0.7640, -1.4268, -0.5012,  1.2167],
        [-0.4578, -0.5621, -0.5652, -0.4056,  0.2509],
        [-0.5617,  0.8141, -0.1722, -0.1264,  0.2285]])

does this mean that for the first instance the right class is the 5th, for the second also the 5th, and for the last instance the 2nd (the largest values)?

The mentioned class indices might not be the right ones (that is determined by the target), but they are the predicted ones (the classes with the highest probability).

If you want probability values, you could use F.softmax to get values in the range [0, 1].
However, do not pass these values to the criterion; use them only for debugging/printing purposes.
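For instance, taking the logits from the post above and storing them in output:

import torch
import torch.nn.functional as F

output = torch.tensor([[ 0.4657, -0.7640, -1.4268, -0.5012,  1.2167],
                       [-0.4578, -0.5621, -0.5652, -0.4056,  0.2509],
                       [-0.5617,  0.8141, -0.1722, -0.1264,  0.2285]])

probs = F.softmax(output, dim=1)  # probabilities in [0, 1]
print(probs.sum(dim=1))           # tensor([1., 1., 1.]): each row sums to 1

Keep passing the raw logits (output), not probs, to nn.CrossEntropyLoss.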


So the positions I mentioned above are the “predicted classes”, right?

Thank you very much for your replies, I really appreciate it!

Yes, that’s correct. The highest logit (in this setup) gives you the predicted class.
That’s also why you can call torch.argmax directly on the logits without applying softmax, as this won’t change the predicted classes (the max logit will still be the max value after softmax).
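A quick check of that claim, reusing output from the snippet above:

pred_from_logits = torch.argmax(output, dim=1)                   # tensor([4, 4, 1])
pred_from_probs = torch.argmax(F.softmax(output, dim=1), dim=1)  # same indices
print(torch.equal(pred_from_logits, pred_from_probs))            # True: softmax preserves the ordering within each row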

I used torch.argmax before training the model to convert one-hot encoded data of shape (b, class, h, w) = (16, 12, 256, 256) to (16, 256, 256). Now I can use the cross entropy loss function. Is this the right way to train a model for a segmentation task?

Yes, as explained here.
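A minimal sketch of that conversion, with random tensors standing in for the real data and model output (shapes taken from the question):

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, num_classes, h, w = 16, 12, 256, 256

# one-hot encoded target of shape (b, class, h, w)
one_hot = F.one_hot(
    torch.randint(0, num_classes, (batch_size, h, w)), num_classes
).permute(0, 3, 1, 2)

# convert to class indices of shape (b, h, w)
target = torch.argmax(one_hot, dim=1)

# the model output keeps the class dimension: (b, class, h, w) logits
output = torch.randn(batch_size, num_classes, h, w, requires_grad=True)

loss = nn.CrossEntropyLoss()(output, target)  # spatial targets are supported
loss.backward()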
