Looking for a cross entropy loss that accepts two tensors of the same shape

The cross-entropy loss function in torch.nn.CrossEntropyLoss takes inputs of shape (N, C) and targets of shape (N). That is, the target is one integer per sample: the index of the class that the trained model should select.
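A minimal sketch of those shapes, assuming a hypothetical batch of N = 4 samples and C = 3 classes with random logits (the tensor names are illustrative only). It also recomputes the loss by hand to show what the class index is used for:

```python
import torch
import torch.nn as nn

N, C = 4, 3
logits = torch.randn(N, C)            # raw, unnormalized scores: shape (N, C)
target = torch.tensor([0, 2, 1, 2])   # one class index per sample: shape (N,)

criterion = nn.CrossEntropyLoss()     # applies log_softmax internally
loss = criterion(logits, target)      # scalar, averaged over the N samples

# The same value by hand: pick the log-probability of each sample's
# target class and average the negatives over the batch.
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(N), target].mean()
print(loss.item(), manual.item())
```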

In my case, my targets are already formatted as one-hot vectors, and the output of my model has already gone through a softmax function. Now I'm looking for a loss function that calculates the loss between my model's predictions and the gold targets. Is there already such a loss function in PyTorch, or do I need to write my own? If I have to write my own, could you please point me in the right direction?

How about this:

import torch

def cross_entropy(output: torch.Tensor, target: torch.Tensor):
    # output: softmax probabilities, shape (N, C); target: one-hot, shape (N, C)
    return -1.0 * (output.log() * target).mean()

You could easily transform the one-hot encoded tensors to class index tensors using target = torch.argmax(one_hot, 1) and just remove the softmax at the end of your model.
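A minimal sketch of that suggestion, assuming a hypothetical batch of four one-hot targets over three classes and a model whose final softmax has been removed, so that it outputs raw logits:

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 3)   # model output *without* a final softmax
one_hot = torch.tensor([[1., 0., 0.],
                        [0., 0., 1.],
                        [0., 1., 0.],
                        [0., 0., 1.]])

target = torch.argmax(one_hot, 1)             # class indices, shape (4,)
loss = nn.CrossEntropyLoss()(logits, target)  # softmax is applied inside the loss
print(target, loss)
```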
Generally, I wouldn't recommend rewriting PyTorch functions unless you have a valid use case for it since, for example, the built-in implementation uses the log-sum-exp trick internally to stabilize the calculations numerically.
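To illustrate why that matters, a small sketch (assuming float32 tensors, PyTorch's default dtype) comparing a softmax followed by log against torch.log_softmax on an extreme logit:

```python
import torch

logits = torch.tensor([[0.0, 200.0]])   # float32 by default

# Naive: softmax first, then log. exp(-200) underflows to 0 in float32,
# so the log of that probability becomes -inf.
naive = torch.softmax(logits, dim=1).log()

# log_softmax subtracts the max and uses log-sum-exp, so it stays finite.
stable = torch.log_softmax(logits, dim=1)

print(naive)    # first entry is -inf
print(stable)   # first entry is -200.0
```

If the target class of this sample were class 0, a loss built on the naive version would be infinite, while the stable version gives a finite (if large) value.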

Anyway, I think the .mean() in your formula averages over all N * C elements, i.e. over the classes as well as the samples. For one-hot targets this yields a loss that is C times smaller than the corresponding one from nn.CrossEntropyLoss.
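A minimal sketch of a version that sums over the class dimension and averages only over the batch. It takes raw logits instead of softmax outputs so that log_softmax can apply the log-sum-exp trick; the name soft_cross_entropy is hypothetical. For one-hot targets it should match nn.CrossEntropyLoss with class indices. Since PyTorch 1.10, nn.CrossEntropyLoss / F.cross_entropy also accept class-probability targets of shape (N, C) directly, which the test below relies on.

```python
import torch
import torch.nn.functional as F

def soft_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Cross entropy between raw logits (N, C) and a target distribution (N, C).

    Each row of `target` is expected to sum to 1 (one-hot or soft labels).
    """
    log_probs = F.log_softmax(logits, dim=1)        # numerically stable log(softmax)
    per_sample = -(target * log_probs).sum(dim=1)   # sum over classes -> shape (N,)
    return per_sample.mean()                        # average over the batch only

logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 2])
one_hot = F.one_hot(labels, num_classes=3).float()

print(soft_cross_entropy(logits, one_hot))
print(F.cross_entropy(logits, labels))   # same value for one-hot targets
```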

Thanks. As a matter of fact, the reason I turned to this approach is a use case that the current cross-entropy implementation doesn't support. And I'd be really surprised if I were the first to ask why the PyTorch developers decided to implement a special case of cross entropy instead of the general one.

Anyway, the current implementation of cross entropy in PyTorch does not support multi-class targets, i.e. problems where none or more than one class can be active in the target at the same time. I guess I'll go with my own function then.


Hi Mehran!

Let me draw a distinction between two different problems. When I
say “multiclass” I mean more than two classes, that is, not binary.
(So, for example, classes A, B, C, and D, rather than just “yes” or
“no.”) But each sample is of exactly one of the classes.

When I say “multilabel” I mean that a given sample can have more
than one class label (e.g., this image has both a cat and a dog in
it, but not a fish).

This last use case that you talk about is what I would call multilabel:
“none or more than one class be active in the target at the same time.”

If you think about it, a multilabel problem (even for only two classes)
is really a binary classification problem, just with an additional “class”
dimension. (That is, cat – yes or no, dog – yes or no, fish – yes or no.)
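To make the cat/dog/fish picture concrete, a multilabel target can be stored as an (N, C) tensor of 0s and 1s, one binary column per class. A small sketch, assuming the column order [cat, dog, fish] and a hypothetical list of active label indices per sample:

```python
import torch

classes = ["cat", "dog", "fish"]          # column order (an assumption)
labels_per_sample = [[0, 1], [], [2]]     # cat+dog, nothing, fish only

target = torch.zeros(len(labels_per_sample), len(classes))
for i, labels in enumerate(labels_per_sample):
    for j in labels:
        target[i, j] = 1.0                # class j is present in sample i
print(target)
```

Unlike the one-hot rows of a multiclass target, a row here can sum to 0, 1, or more.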

BCELoss (and, probably better to use, BCEWithLogitsLoss) supports
this extra class dimension. I think this is the “cross entropy loss that
accepts two tensors of the same shape” that you are looking for.
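A minimal sketch of BCEWithLogitsLoss on two tensors of the same shape, with random logits and a random multi-hot target (both hypothetical). It also spells out the per-element binary cross entropy that the loss averages, and shows that BCELoss needs a sigmoid applied first:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(3, 5)                   # one raw score per (sample, class)
target = torch.randint(2, (3, 5)).float()    # multi-hot target, same shape

loss = nn.BCEWithLogitsLoss()(logits, target)

# Same thing by hand: binary cross entropy for every element, averaged
# over all N*C entries. logsigmoid keeps it numerically stable.
manual = -(target * F.logsigmoid(logits)
           + (1 - target) * F.logsigmoid(-logits)).mean()

# BCELoss gives the same value but expects probabilities, i.e. a sigmoid first.
via_bce = nn.BCELoss()(torch.sigmoid(logits), target)
print(loss, manual, via_bce)
```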

If you don’t want to think about it in terms of a binary classification
problem, you could use the equivalent, but differently named,
MultiLabelSoftMarginLoss.

Here is a PyTorch (version 0.3.0) session showing that these two
loss functions are really the same. (You can get rid of the 0.3.0 oddities,
and it will run with an up-to-date PyTorch. It will probably even run
with an up-to-date PyTorch as-is.)

>>> import torch
>>> l1 = torch.nn.BCEWithLogitsLoss()
>>> l2 = torch.nn.MultiLabelSoftMarginLoss()
>>> pred = torch.randn ((3, 5))
>>> targ = torch.bernoulli (0.5 * torch.ones ((3, 5)))
>>> pred = torch.autograd.Variable (pred)
>>> targ = torch.autograd.Variable (targ)
>>> l1 (pred, targ)
Variable containing:
[torch.FloatTensor of size 1]

>>> l2 (pred, targ)
Variable containing:
[torch.FloatTensor of size 1]

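You would probably want to replace the call:

targ = torch.bernoulli (0.5 * torch.ones ((3, 5)))

with:

targ = torch.randint (2, (3, 5)).float()

The upper bound of randint() is exclusive, so 2 is needed to get both 0s and 1s, and .float() converts the integer result into the floating-point target these losses expect.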

(The old 0.3.0 version didn’t have randint() yet.)

Note that the input tensor, pred, and target tensor, targ, are indeed
of the same shape, in this example, (3, 5).
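For reference, a sketch of the same comparison written for a current PyTorch release (an assumption: any version where tensors no longer need the Variable wrapper), with a tolerance check instead of eyeballing the printed values:

```python
import torch

l1 = torch.nn.BCEWithLogitsLoss()
l2 = torch.nn.MultiLabelSoftMarginLoss()

pred = torch.randn(3, 5)                  # no Variable wrapper needed any more
targ = torch.randint(2, (3, 5)).float()   # 0/1 targets; high bound 2 is exclusive

loss1 = l1(pred, targ)
loss2 = l2(pred, targ)
print(loss1, loss2)                       # equal up to floating-point rounding
```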

Good luck.

K. Frank


Thanks K. Frank,

That was very informative. You are right, using the right terminology is crucial. Thanks for correcting me.

I guess I’ll go with MultiLabelSoftMarginLoss, even though it seems just like BCEWithLogitsLoss (I’m not sure what their differences are at this point!).
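For what it's worth, one practical difference shows up once you leave the default reduction="mean": BCEWithLogitsLoss keeps one loss per (sample, class) element, whereas MultiLabelSoftMarginLoss first averages over the classes of each sample. BCEWithLogitsLoss also accepts a pos_weight argument for up-weighting positive examples, which MultiLabelSoftMarginLoss does not. A small sketch with random data:

```python
import torch

pred = torch.randn(3, 5)
targ = torch.randint(2, (3, 5)).float()

bce = torch.nn.BCEWithLogitsLoss(reduction="none")(pred, targ)
mlsm = torch.nn.MultiLabelSoftMarginLoss(reduction="none")(pred, targ)
print(bce.shape)    # torch.Size([3, 5]): one loss per (sample, class)
print(mlsm.shape)   # torch.Size([3]): one loss per sample (mean over classes)

# Hence with reduction="sum" the two differ by a factor of C = 5.
bce_sum = torch.nn.BCEWithLogitsLoss(reduction="sum")(pred, targ)
mlsm_sum = torch.nn.MultiLabelSoftMarginLoss(reduction="sum")(pred, targ)
print(bce_sum, 5 * mlsm_sum)
```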

Thanks again,