Hi Mehran!
Let me draw a distinction between two different problems. When I
say “multiclass” I mean more than two classes, that is, not binary.
(So, for example, classes A, B, C, and D, rather than just “yes” or
“no.”) But each sample is of exactly one of the classes.
When I say “multilabel” I mean that a given sample can have more
than one class label (e.g., this image has both a cat and a dog in
it, but not a fish).
This last use case that you talk about is what I would call multilabel:
“none or more than one class be active in the target at the same time.”
If you think about it, a multilabel problem (even for only two classes)
is really a binary classification problem, just with an additional “class”
dimension. (That is, cat – yes or no, dog – yes or no, fish – yes or no.)
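Concretely, a multilabel target for one sample is a multi-hot vector with one 0/1 entry per class. A quick sketch (the class order [cat, dog, fish] is just an assumption for illustration):

```python
import torch

# Hypothetical class order for illustration: [cat, dog, fish].
# An image with both a cat and a dog, but no fish:
targ = torch.tensor([1.0, 1.0, 0.0])

# A sample with no active classes is also legal ("none"):
empty = torch.zeros(3)

print(targ.tolist())    # [1.0, 1.0, 0.0]
print(empty.tolist())   # [0.0, 0.0, 0.0]
```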
BCELoss (and, probably better to use, BCEWithLogitsLoss) supports
this extra class dimension. I think this is the “cross entropy loss that
accepts two tensors of the same shape” that you are looking for.
If you don’t want to think about it in terms of a binary classification
problem, you could use the equivalent, but differently named,
MultiLabelSoftMarginLoss.
Here is a pytorch (version 0.3.0) session that shows that these two
loss functions are really the same. (You can get rid of the 0.3.0 oddities,
and this will run with an up-to-date pytorch. It will probably even run
with an up-to-date pytorch as-is.)
>>> import torch
>>> l1 = torch.nn.BCEWithLogitsLoss()
>>> l2 = torch.nn.MultiLabelSoftMarginLoss()
>>> pred = torch.randn ((3, 5))
>>> targ = torch.bernoulli (0.5 * torch.ones ((3, 5)))
>>> pred = torch.autograd.Variable (pred)
>>> targ = torch.autograd.Variable (targ)
>>> l1 (pred, targ)
Variable containing:
0.8563
[torch.FloatTensor of size 1]
>>> l2 (pred, targ)
Variable containing:
0.8563
[torch.FloatTensor of size 1]
You would probably want to replace the call to:
targ = torch.bernoulli (0.5 * torch.ones ((3, 5)))
with:
targ = torch.randint (2, (3, 5)).float()
(The old 0.3.0 version didn’t have randint() yet. Note that randint()’s
upper bound is exclusive, so you need 2, not 1, to get both zeros and
ones, and the .float() because BCEWithLogitsLoss expects a
floating-point target.)
Note that the input tensor, pred, and target tensor, targ, are indeed
of the same shape, in this example, (3, 5).
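For reference, here is a sketch of the same check with a current pytorch, with the 0.3.0 oddities removed (no Variables needed; the manual seed is just to make the run reproducible):

```python
import torch

torch.manual_seed(0)

l1 = torch.nn.BCEWithLogitsLoss()
l2 = torch.nn.MultiLabelSoftMarginLoss()

pred = torch.randn(3, 5)                 # logits, shape (3, 5)
targ = torch.randint(2, (3, 5)).float()  # 0.0 / 1.0 targets, same shape

print(l1(pred, targ))
print(l2(pred, targ))
print(torch.allclose(l1(pred, targ), l2(pred, targ)))  # True
```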
Good luck.
K. Frank