Loss Function for Multi-class Classification with Probabilities as Output

I’m working on a multi-class model where the target for each input sample is a one-hot encoded vector of size C. Since the output should be a vector of C probabilities, I’m having trouble figuring out which combination of output layer activation and loss function to use.

Based on what I’ve read so far, vanilla nn.NLLLoss and nn.CrossEntropyLoss can’t be used, since they expect the target to be a class label rather than a one-hot vector. My guess is that I would either need to tweak these loss functions to accept a one-hot encoded target or write my own loss. I’m somewhat confused about how to proceed from here, since I don’t know how each of these options is going to affect the final model performance.
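For the "write my own loss" option, here is a minimal sketch of a cross-entropy that takes a per-sample probability vector as the target (a one-hot vector is the special case where all the mass is on one class). The function name `soft_target_ce_loss` and the made-up batch are hypothetical. For one-hot targets it produces the same value as `nn.CrossEntropyLoss` with class-index targets, so the choice between the two should not by itself change what the model learns.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def soft_target_ce_loss(logits, target_probs):
    """Hypothetical helper: cross-entropy against per-sample probability targets.

    logits: (N, C) raw scores from the last linear layer (no activation applied).
    target_probs: (N, C) rows that each sum to 1 (one-hot rows qualify).
    """
    log_probs = F.log_softmax(logits, dim=1)             # (N, C) log-probabilities
    per_sample = -(target_probs * log_probs).sum(dim=1)  # (N,) loss per sample
    return per_sample.mean()                              # average over the batch

torch.manual_seed(0)
logits = torch.randn(4, 3)  # made-up network outputs: batch of 4, C = 3
one_hot = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=torch.float32)

custom = soft_target_ce_loss(logits, one_hot)
builtin = nn.CrossEntropyLoss()(logits, one_hot.argmax(dim=1))
print(custom.item(), builtin.item())  # the two values agree
```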


In the docs, we can see that nn.CrossEntropyLoss is the combination of nn.LogSoftmax (which converts the raw outputs of your network into log-probabilities) and nn.NLLLoss. Therefore, if you use nn.CrossEntropyLoss, you do not need any output layer activation, as it is already included in the loss function. If you prefer to use nn.NLLLoss, then nn.LogSoftmax is the output layer activation to pair it with.
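A quick numerical check of that equivalence, on made-up logits and class-index labels: the two pairings give the same loss, while putting an extra softmax in front of nn.CrossEntropyLoss normalizes the outputs twice and gives a different value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)           # made-up raw outputs of the final nn.Linear
labels = torch.tensor([2, 0, 2, 1])  # class indices, dtype int64

# Option A: no output activation, nn.CrossEntropyLoss on the raw logits
loss_ce = nn.CrossEntropyLoss()(logits, labels)

# Option B: nn.LogSoftmax as the output activation, followed by nn.NLLLoss
log_probs = nn.LogSoftmax(dim=1)(logits)
loss_nll = nn.NLLLoss()(log_probs, labels)

# Pitfall: a softmax before nn.CrossEntropyLoss means the outputs are normalized twice
loss_double = nn.CrossEntropyLoss()(torch.softmax(logits, dim=1), labels)

print(loss_ce.item(), loss_nll.item(), loss_double.item())
```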

You can convert your one-hot encoded target vectors to class labels (indices) in order to use either loss function as described above. Here is an example for cross-entropy:

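```python
import torch
import torch.nn as nn

def one_hot_ce_loss(outputs, targets):
    criterion = nn.CrossEntropyLoss()
    # The position of the 1 in each one-hot row is the class label (int64)
    _, labels = torch.max(targets, dim=1)
    return criterion(outputs, labels)

targets = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=torch.int32)
# Raw, unnormalized scores (logits): no softmax applied before the loss
outputs = torch.rand(size=(4, 3), dtype=torch.float32)
loss = one_hot_ce_loss(outputs, targets)
```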

The example assumes a batch size of 4 and 3 possible classes (C = 3).
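The same conversion works for the nn.NLLLoss route. Below is a minimal sketch with nn.LogSoftmax as the output layer activation; the model shape (10 input features, 3 classes) and the helper name `one_hot_nll_loss` are hypothetical and only there for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical model: 10 input features, C = 3 classes, LogSoftmax as the output activation
model = nn.Sequential(nn.Linear(10, 3), nn.LogSoftmax(dim=1))
criterion = nn.NLLLoss()

def one_hot_nll_loss(log_probs, targets):
    # argmax over the class dimension turns each one-hot row into its class index (int64)
    labels = targets.argmax(dim=1)
    return criterion(log_probs, labels)

inputs = torch.randn(4, 10)  # made-up batch of 4 samples
targets = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=torch.int32)

loss = one_hot_nll_loss(model(inputs), targets)
loss.backward()  # gradients flow back into the model's Linear layer
print(loss.item())
```

As a side note, in PyTorch 1.10 and later nn.CrossEntropyLoss also accepts a float target of shape (N, C) holding class probabilities, so `nn.CrossEntropyLoss()(outputs, targets.float())` should work without the conversion; check which version you are running.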

Hope this helps!



Big thanks for commenting on my thread; your reply solved my problem!
I forgot that nn.CrossEntropyLoss already performs both nn.LogSoftmax and nn.NLLLoss, so I was applying an extra activation on top of the logits in the output layer. Plus, the snippet you provided to convert one-hot labels into class indices for the loss function was key for me.
