# Loss Function for Multi-class with probabilities as output

Hello!
I’m working on a multi-class model where my target is a one-hot encoded vector of size C for each input sample. Since the output should be a vector of probabilities with dimension C, I’m having trouble figuring out which combination of output-layer activation and loss function to use.

Based on what I’ve read so far, vanilla `nn.NLLLoss` and `nn.CrossEntropyLoss` can’t be used directly, since they expect the target to be a class label (an index) rather than a one-hot vector. My guess is that I would either need to tweak these loss functions to accept one-hot encoded targets or write my own loss. I’m somewhat confused about how to proceed from here, since I don’t know how each of these options is going to impact the final model performance.

Hello,

In the docs, we can see that `nn.CrossEntropyLoss` is the combination of `nn.LogSoftmax` (which converts the raw outputs of your network into log-probabilities) and `nn.NLLLoss`. Therefore, if you use `nn.CrossEntropyLoss`, you do not need any output-layer activation, as it is already included in the loss function. If you prefer to use `nn.NLLLoss`, add an `nn.LogSoftmax(dim=1)` as the last layer of your model, since `nn.NLLLoss` expects log-probabilities as its input.
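A quick sketch to check that equivalence numerically, using random logits as a stand-in for network outputs (the seed, batch size of 4 and C = 3 are arbitrary choices for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)           # raw network outputs: batch of 4, C = 3
labels = torch.tensor([2, 0, 2, 1])  # class indices (int64)

# CrossEntropyLoss applied directly to the raw logits
ce = nn.CrossEntropyLoss()(logits, labels)
# NLLLoss applied to the log-probabilities produced by LogSoftmax
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), labels)

print(torch.allclose(ce, nll))  # True
```

So the two setups give the same loss; the only difference is whether the log-softmax lives in the model or in the criterion.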

You can convert your one-hot encoded target vector to a label in order to use either loss function as above. Here is an example for cross-entropy:

```python
import torch
import torch.nn as nn

def one_hot_ce_loss(outputs, targets):
    criterion = nn.CrossEntropyLoss()
    # The position of the 1 in each one-hot row is the class label
    _, labels = torch.max(targets, dim=1)
    return criterion(outputs, labels)
```

```python
targets = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=torch.int32)
outputs = torch.rand(size=(4, 3), dtype=torch.float32)  # stands in for raw logits
loss = one_hot_ce_loss(outputs, targets)
print(loss)
```
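If you would rather keep the one-hot (or soft) targets as they are, here is a minimal sketch of the "write my own loss" option, next to the built-in alternative. `soft_target_ce_loss` is a hypothetical helper name; the built-in call assumes PyTorch 1.10 or later, where `nn.CrossEntropyLoss` also accepts float targets of the same shape as the input and treats them as class probabilities:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def soft_target_ce_loss(outputs, targets):
    # Hypothetical helper: cross-entropy against a target distribution,
    # i.e. the batch mean of -sum_c targets[c] * log_softmax(outputs)[c]
    log_probs = F.log_softmax(outputs, dim=1)
    return -(targets * log_probs).sum(dim=1).mean()

targets = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=torch.float32)
outputs = torch.rand(size=(4, 3), dtype=torch.float32)  # stands in for raw logits

manual = soft_target_ce_loss(outputs, targets)
# Assumes PyTorch >= 1.10: float targets with the input's shape are probabilities
builtin = nn.CrossEntropyLoss()(outputs, targets)
# With strictly one-hot targets, both match the index-based loss
index_based = nn.CrossEntropyLoss()(outputs, targets.argmax(dim=1))

print(manual.item(), builtin.item(), index_based.item())
```

For strictly one-hot targets all three agree, so converting to indices costs nothing; the probability-target forms only matter if your targets can be genuinely soft (e.g. label smoothing or averaged annotations).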

The example assumes a batch size of 4 and a number of possible classes C of 3.

Hope this helps!


I forgot that `nn.CrossEntropyLoss` already performs both `nn.LogSoftmax` and `nn.NLLLoss`, so I was adding an extra activation to the output layer. Plus, the snippet you provided to transform one-hot labels into class indices for the loss function was key for me.
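For anyone running into the same issue, a small sketch of what changes when the model already applies a softmax before `nn.CrossEntropyLoss` (assuming softmax is the extra transformation in question; the seed and the scaling of the logits are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3) * 5       # raw outputs, scaled so the effect is visible
labels = torch.tensor([2, 0, 2, 1])
criterion = nn.CrossEntropyLoss()

# Correct: pass raw logits, the criterion applies log-softmax itself
correct = criterion(logits, labels)
# Mistake: softmax in the model, then log-softmax again inside the criterion
double = criterion(torch.softmax(logits, dim=1), labels)

print(correct.item(), double.item())
```

Because the second softmax only ever sees values in [0, 1], the doubled version is confined to roughly [0.55, 1.55] for C = 3 (between log(2 + e) - 1 and log(2 + e)), no matter how confident or wrong the model is, so it no longer measures the cross-entropy you intended.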