Loss functions do not support subset datasets

Hi there,

I've run into a problem: PyTorch loss functions such as functional.nll_loss and nn.CrossEntropyLoss don't seem to support a subset of a dataset.

I want to split the MNIST and CIFAR-10 datasets for a specific use. For example, when loading MNIST, I explicitly exclude the samples with label "6" from my data and label sets, so the training/test datasets only contain digits 0-5 and 7-9.
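For reference, here is a minimal sketch of the filtering step described above. It assumes the data is already in tensors: a random TensorDataset stands in for MNIST so it runs without a download, and `excluded_label`/`keep_idx` are illustrative names. With torchvision's MNIST, the same index filtering could be applied to the dataset's `targets` attribute. Note that the labels themselves are left unchanged, which is exactly what leads to the problem below.

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Hypothetical stand-in for MNIST: 100 random 28x28 "images" with labels 0-9.
torch.manual_seed(0)
images = torch.randn(100, 1, 28, 28)
labels = torch.arange(100) % 10  # every digit appears exactly 10 times

full = TensorDataset(images, labels)

# Keep only the indices whose label is not the excluded digit.
excluded_label = 6
keep_idx = (labels != excluded_label).nonzero(as_tuple=True)[0].tolist()
subset = Subset(full, keep_idx)

print(len(subset))  # 90
remaining = sorted({int(y) for _, y in subset})
print(remaining)    # [0, 1, 2, 3, 4, 5, 7, 8, 9]
```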

However, the loss functions fail:
loss = F.nll_loss(output, target)
loss = nn.CrossEntropyLoss()(output, target)
because they expect:
Input: (N, C), where C = number of classes
Target: (N), where each value satisfies 0 <= target[i] <= C-1

In my case this fails because C = 9, but my targets take the values 0-5 and 7-9 instead of 0-8.
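A small sketch reproducing the failure on CPU, assuming 9 output columns as in the question. Only the target 9 is out of range. Depending on the PyTorch version the error surfaces as an IndexError or a RuntimeError, so both are caught here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 4, 9                              # 9 output columns, as in the question
logits = torch.randn(N, C)

ok_target = torch.tensor([0, 5, 7, 8])   # all within 0..C-1
bad_target = torch.tensor([0, 5, 7, 9])  # 9 is out of range when C = 9

loss = F.cross_entropy(logits, ok_target)  # works

try:
    F.cross_entropy(logits, bad_target)
    failed = False
except (IndexError, RuntimeError) as err:
    failed = True
    print(type(err).__name__, err)       # e.g. "Target 9 is out of bounds."
```

Note that labels 7 and 8 do not raise an error here. They are silently scored against columns 7 and 8, which is also wrong if those columns are meant to represent other digits.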

Could somebody tell me how to solve this?

I'd appreciate any suggestions.

What exactly is your output?

It should be a tensor of size (N, C), where output[n, c] is the score for sample n belonging to class c: a log-probability for F.nll_loss, or an unnormalized logit for nn.CrossEntropyLoss. Is that what you have? (It doesn't sound like it: if your output had a column for every digit, the column for class 6 would simply learn a near-zero probability, but the error is telling you that you don't have enough classes.)
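To make the expected inputs concrete, here is a short sketch with random data. It uses the shapes from this thread, (64, 9). F.nll_loss takes log-probabilities, so log_softmax is applied first. nn.CrossEntropyLoss is a module that must be instantiated and then called on raw logits. The two give the same value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 64, 9
logits = torch.randn(N, C)          # raw scores, shape (N, C)
target = torch.randint(0, C, (N,))  # class indices in 0..C-1, shape (N,)

# nll_loss expects log-probabilities, so apply log_softmax over the class dim.
loss_nll = F.nll_loss(F.log_softmax(logits, dim=1), target)

# CrossEntropyLoss combines log_softmax and nll_loss; it takes raw logits.
criterion = nn.CrossEntropyLoss()
loss_ce = criterion(logits, target)

print(torch.allclose(loss_nll, loss_ce))  # True
```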

My output is in the right format: (N, C) = (64, 9). The issue is that with 9 classes, the loss function expects labels 0-8. However, since I only removed the digit "6", the remaining labels are 0, 1, 2, 3, 4, 5, 7, 8, 9, and the label 9 exceeds the largest value the loss accepts.

Okay, so I'm assuming that:

  • For all n and i <= 5: output[n, i] corresponds to label i.
  • output[n, 6] corresponds to the label 7
  • output[n, 7] corresponds to the label 8
  • output[n, 8] corresponds to the label 9

If so, you can remap each target to the output index it corresponds to:

def map_labels(label):
    # Labels 7, 8, 9 shift down to indices 6, 7, 8; labels 0-5 are unchanged.
    if label > 6:
        return label - 1
    return label
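For a whole batch, the same mapping can be done with tensor indexing instead of calling map_labels element by element. The sketch below builds a lookup table from the kept labels, which also works if a different class is excluded (e.g. in CIFAR-10). It includes the inverse mapping needed to turn predicted indices back into digits. The names `remap_targets` and `indices_to_labels` are hypothetical helpers, not PyTorch APIs. Alternatively, map_labels itself could be passed as the `target_transform` argument of a torchvision dataset.

```python
import torch

EXCLUDED = 6        # the digit dropped from the dataset (as in this thread)
NUM_ORIGINAL = 10   # digits 0-9 before filtering

# Sorted labels that remain: [0, 1, 2, 3, 4, 5, 7, 8, 9]
kept = torch.tensor([c for c in range(NUM_ORIGINAL) if c != EXCLUDED])

# Lookup table: original label -> contiguous index 0..8 (-1 marks the excluded label).
to_index = torch.full((NUM_ORIGINAL,), -1, dtype=torch.long)
to_index[kept] = torch.arange(len(kept))

def remap_targets(target):
    """Map a batch of original labels to class indices 0..len(kept)-1."""
    return to_index[target]

def indices_to_labels(pred):
    """Map predicted class indices back to the original digit labels."""
    return kept[pred]

target = torch.tensor([0, 5, 7, 9])
idx = remap_targets(target)
print(idx)                     # tensor([0, 5, 6, 8])
print(indices_to_labels(idx))  # tensor([0, 5, 7, 9])
```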

An alternative is to keep the labels as they are, give the model 10 outputs so it can predict any label from 0 to 9, and rely on training to teach it never to predict 6 (no training sample carries that label).
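A minimal sketch of that alternative, assuming a hypothetical single-layer classifier for 28x28 images. The model keeps 10 outputs and the original labels, so no remapping is needed. The masking of logit 6 at inference is an optional extra, not something suggested in this thread. It simply guarantees that class 6 is never predicted instead of relying on training alone.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical tiny classifier with 10 outputs (digits 0-9),
# even though label 6 never appears in the filtered data.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()

images = torch.randn(4, 1, 28, 28)
target = torch.tensor([0, 5, 7, 9])  # original labels, no remapping

logits = model(images)               # shape (4, 10)
loss = criterion(logits, target)     # valid: every label is in 0..9
loss.backward()

# Optional: at inference, rule out class 6 by masking its logit.
masked = logits.detach().clone()
masked[:, 6] = float("-inf")
pred = masked.argmax(dim=1)          # can never be 6
```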

Thanks Richard. I used this approach to solve the issue.