Feature request: NLLLoss / CrossEntropyLoss that accepts one-hot target

I’ve been struggling to write a proper loss function for a combination of multiclass and multilabel classification.
The fact that NLLLoss/CrossEntropyLoss only accept class indices (categorical targets), with no equivalent that takes a one-hot vector, is a real limitation.

Use case - For example with 10 classes:

  • classes 0 to 4 are exclusive (group A)
  • classes 5 and 6 are exclusive (group B)

Group A, group B and classes 7 to 9 are independent.

I would like to use cross-entropy for group A, cross-entropy for group B, and binary cross-entropy for classes 7 to 9.

The target with the true labels is a one-hot vector.
I managed to split it and format it for cross_entropy and binary_cross_entropy + sigmoid, but the result is quite ugly.

1/ Details on how to one-hot-encode and how to revert from one-hot-encoding.

I tried using gather but couldn’t manage to do it, so I used masked_select instead. Is there a way to do it with gather?

import torch
from torch import nn
from torch.autograd import Variable

target = Variable(torch.LongTensor([1, 0, 4]))
print(target)
target_onehot = Variable(torch.zeros(3, 5))
target_onehot.scatter_(1, target.view(-1, 1), 1)  # trailing _ means in-place
print(target_onehot)

val = torch.arange(0,5)
print(val)
val = val.expand(3, 5)  # expand broadcasts the row across the batch
print(val)

new_target = val.masked_select(target_onehot.data.byte())  # keep only the entries where the one-hot mask is 1
print(new_target)

Output:

Variable containing:
 1
 0
 4
[torch.LongTensor of size 3]

Variable containing:
 0  1  0  0  0
 1  0  0  0  0
 0  0  0  0  1
[torch.FloatTensor of size 3x5]


 0
 1
 2
 3
 4
[torch.FloatTensor of size 5]


 0  1  2  3  4
 0  1  2  3  4
 0  1  2  3  4
[torch.FloatTensor of size 3x5]


 1
 0
 4
[torch.FloatTensor of size 3]
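
By the way, the reverse mapping can also be done without masked_select by taking the position of the maximum along the class dimension. A small sketch on the same target_onehot as above (using torch.max rather than gather):

# the index of the 1 in each one-hot row is simply the argmax along dim 1
_, recovered = torch.max(target_onehot, 1)
print(recovered)  # 1, 0, 4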

2/ Using that in a custom loss function:

My forward function has to be defined like this for group A:

    def forward(self, input, target):
        # cross_entropy wants class indices, not one-hot rows,
        # so reverse the one-hot encoding for group A (classes 0 to 4)
        groupA_target = Variable(torch.arange(0, 5)
                                      .expand(target.size(0), 5)
                                      .masked_select(target[:, :5].data.byte().cpu())
                                      .long().cuda(),
                                 requires_grad=False)

        loss_groupA = F.cross_entropy(input[:, :5],
                                      groupA_target,
                                      self.groupA_weight,
                                      self.size_average)
        [...]

Explanation:

  • data.byte().cpu() because masked_select wants a ByteTensor mask on the CPU
  • .long().cuda() because F.cross_entropy expects a CUDA LongTensor

Ideally I would like an F.cross_entropy_OneHot function that I could use like this:

    def forward(self, input, target):
        loss_groupA = F.cross_entropy(input[:, :5],
                                      target[:, :5],
                                      self.groupA_weight,
                                      self.size_average)
        [...]

This is already possible with binary_cross_entropy + sigmoid activation.
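
Until something like that exists, a hand-rolled workaround built directly on log_softmax seems to do the job. A sketch only (cross_entropy_onehot is a made-up name; unweighted and averaged over the batch):

    import torch.nn.functional as F

    def cross_entropy_onehot(input, target):
        # input: raw logits of shape (N, C); target: one-hot rows of shape (N, C)
        log_probs = F.log_softmax(input, dim=1)
        return -(target * log_probs).sum(1).mean()

For a strict one-hot target this just picks out -log(p) of the true class, so it should match F.cross_entropy up to the missing weight / size_average handling.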

3/ Bonus question: NLLLoss pairs with logsoftmax, but is there a loss that pairs with logsigmoid?


I think you need to port https://github.com/torch/nn/blob/master/MultiLabelSoftMarginCriterion.lua to PyTorch. You just need to write the forward pass in a function, and the rest should be handled by autograd.

@fmassa, are you replying to point 3?

MultiLabelSoftMarginLoss is already there and I’m aware of it. It’s implemented in terms of binary_cross_entropy + sigmoid.

As far as I know, computing logsigmoid (or logsoftmax) in one go is:

  • faster than applying sigmoid and then taking the log (the log is done inside binary_cross_entropy)
  • numerically more stable (see the small example below)

Hence my question: is there a loss function that corresponds to logsigmoid, the way NLLLoss corresponds to logsoftmax?
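
A toy illustration of the stability point (nothing assumed beyond torch.sigmoid and F.logsigmoid):

    import torch
    import torch.nn.functional as F

    x = torch.FloatTensor([-150.0])      # a very negative logit
    print(torch.log(torch.sigmoid(x)))   # -inf: sigmoid underflows to 0 before the log
    print(F.logsigmoid(x))               # about -150: computed in one stable step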

I tried this:

# weather_labels: partly_cloudy, cloudy, clear, hazy
# other_labels: the other labels, e.g. primary, ...

def criterion(weather_logits, weather_labels, other_logits, other_labels):
    w0 = 4/17   # 0.33333  ## <todo> best weight?
    w1 = 13/17  # 0.66666
    loss0 = nn.CrossEntropyLoss()(weather_logits, Variable(weather_labels))
    loss1 = nn.MultiLabelSoftMarginLoss()(other_logits, Variable(other_labels))
    loss = w0*loss0 + w1*loss1

    return loss

# in the model
...
class MyNet(nn.Module):
    def __init__(self, in_channels=3, num_weather_classes=4, num_other_classes=13):
        super(MyNet, self).__init__()
        ...
        self.logit0 = nn.Linear(512, num_weather_classes)
        self.logit1 = nn.Linear(512, num_other_classes)

    def forward(self, x):
        ...
        out0 = self.logit0(out)
        out1 = self.logit1(out)

        return out0, out1




# in the training loop
...
for it, (images, labels, indices) in enumerate(train_loader, 0):

    optimizer.zero_grad()

    # weather_labels: one-hot converted to class index
    _, weather_labels = torch.max(labels[:, :4], 1)

    # other_labels: stays one-hot
    other_labels = labels[:, 4:]

    weather_logits, other_logits = net(Variable(images.cuda()))
    loss = criterion(weather_logits, weather_labels.cuda(), other_logits, other_labels.cuda())
    loss.backward()
    optimizer.step()


I strongly support this feature request: an NLLLoss / CrossEntropyLoss
that accepts a one-hot target! What is the situation now?

From the current docs of CrossEntropyLoss / NLLLoss
(http://pytorch.org/docs/master/nn.html?highlight=nllloss#torch.nn.NLLLoss),
it is hard to interpret CrossEntropyLoss or NLLLoss as a loss function in the usual sense.
From a mathematical point of view, a loss function should satisfy loss(x, y) = 0 if and only if x = y.
For example, torch.nn.MSELoss satisfies this condition. It would be natural to allow a
one-hot target in CrossEntropyLoss so that it meets this condition too
(with the help of x*log(x) -> 0 as x -> 0). In addition, a one-hot vector is a special discrete
probability distribution, and TensorFlow already supports one-hot targets in its loss function implementations.
Torch should have this feature too!
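
Spelled out, for a one-hot (or probability) target p and a predicted distribution q = softmax(x), the requested loss would be

    loss(p, q) = -sum_i p_i * log(q_i) = -log(q_y)   (one-hot target on class y)

which is zero exactly when q puts all of its mass on the true class y, with the convention 0 * log(0) = 0.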
