Ignore_index in the cross entropy loss

Hi. I think PyTorch calculates the cross entropy loss incorrectly when the ignore_index option is used.
The problem is that, when you specify ignore_index (say, = k), the function only ignores targets with value y = k (in fact, it computes the cross entropy at those positions but returns 0 for them), yet it still makes full use of the logit at index k to compute the normalization term for the other indices. I don't think this is the intended use for most users.

For example, with variable-length sequences, people pad the sequences and use ignore_index as the pad target index so that the padded positions (in both inputs and targets) are not taken into account. If there are n classes, you have to allocate (n+1) classes in the logit dimension (the input of the cross entropy loss) to include the pad class, and then ignore it via the ignore_index option.
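For instance, a padded-batch setup along those lines might look like this (a minimal sketch; the shapes, values, and pad index are just illustrative):

import torch

n_classes = 3               # real classes
pad_idx = n_classes         # the extra (n+1)-th class reserved for padding

# logits over n_classes + 1 entries because the pad class needs its own column
logits = torch.randn(2, 4, n_classes + 1)   # (batch, max_len, n_classes + 1)

# the second sequence is shorter, so its last two targets are the pad index
targets = torch.tensor([[0, 2, 1, 2],
                        [1, 0, pad_idx, pad_idx]])

loss = torch.nn.functional.cross_entropy(
    logits.view(-1, n_classes + 1),   # flatten to (batch * max_len, n_classes + 1)
    targets.view(-1),
    ignore_index=pad_idx)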

Here is an illustrative example of the normalization issue:

import torch

# Test cross entropy loss: first create the data
x = torch.log(torch.tensor([[2, 3, 4]], dtype=torch.float))  # a vector of 3-class logits (one of them could be a padding class)
y1 = torch.tensor([0], dtype=torch.long)
y2 = torch.tensor([1], dtype=torch.long)
y3 = torch.tensor([2], dtype=torch.long)

# calculate the negative log-softmax at each logit index for comparison
print(-torch.nn.functional.log_softmax(x, dim=1))  # tensor([[1.5041, 1.0986, 0.8109]])

# perform log-softmax and the NLL loss in one call (not using ignore_index yet)
print(torch.nn.functional.cross_entropy(x, y1))  # 1.5041
print(torch.nn.functional.cross_entropy(x, y2))  # 1.0986
print(torch.nn.functional.cross_entropy(x, y3))  # 0.8109

# Now let's ignore index 0 and compute the cross entropy loss for index 1
print(torch.nn.functional.cross_entropy(x, y2, ignore_index=0))  # 1.0986
# this is the same value as when index 0 is not excluded from the logits;
# it should ignore the index at the level of the logits, not just at the level of the final target.

# Next, let's calculate the correct cross entropy loss when index 0 is actually ignored completely, in both x and y
x_ignore = x[0][1:].view(1, x.shape[-1] - 1)  # now we have logits for 2 classes
# indices greater than the ignored index are shifted down by 1
y2_ignore = torch.tensor([0], dtype=torch.long)
y3_ignore = torch.tensor([1], dtype=torch.long)
# cross entropy with index 0 ignored, for original index 1 (which now becomes index 0)
print(torch.nn.functional.cross_entropy(x_ignore, y2_ignore))  # 0.8473

In conclusion, I'm raising this issue in case the developers want to revise the ignore_index option. If the current behavior already follows the intended use (ignore only y, not x, and hence still backprop through the ignored logit index via the normalization term of the softmax), then it is simply my misunderstanding of how it should work (I expected both x and y to be ignored at ignore_index).


I think that's not what the ignore_index parameter means. Its description says:

Specifies a target value that is ignored and does not contribute to the input gradient. When size_average is True, the loss is averaged over non-ignored targets.

So what you specify with ignore_index = k is that the elements of the target that have value k will not contribute to the error. And if you specify size_average = True, then the average won't count those elements either.
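A quick numeric check of both points, using reduction='mean' (the modern counterpart of size_average=True); the logits and the class chosen as ignore_index are arbitrary:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 5)
target = torch.tensor([1, 3, 2, 2])

# pretend class 2 marks the padded / ignored positions
ignored = F.cross_entropy(logits, target, ignore_index=2, reduction='mean')

# manual equivalent: drop the ignored rows, then average over what is left
mask = target != 2
manual = F.cross_entropy(logits[mask], target[mask], reduction='mean')

print(ignored.item(), manual.item())   # the two values match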

Its intended use, or at least the way I use it, is for the cases where you add padding to your input so that all instances have the same length.


@adrianjav Yes, a target that has index k will not contribute to the error. But the problem is that class k at the softmax layer is not ignored when calculating the softmax for the other classes (index k still appears in the denominator of the softmax formula, since PyTorch does not drop it).

For example, suppose you have only 2 classes: class 1 and class 2 (your padding class). When you ignore the padding class, the softmax probability of class 1 should always be one (because there is only one class left to consider). But if you use the ignore_index option, it will not return 1 in general, since the padding class is still not eliminated from consideration (and there is even a chance that the model assigns a higher probability to the padding class on unseen data).
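A concrete version of that two-class scenario (the logit values below are made up just to exaggerate the effect):

import torch
import torch.nn.functional as F

# one real class (index 0) plus a padding class (index 1)
logits = torch.tensor([[0.0, 3.0]])   # the padding logit happens to be large
target = torch.tensor([0])

print(F.softmax(logits, dim=1))       # ~[[0.047, 0.953]]: class 0 is NOT 1.0

# ignore_index=1 only skips pad *targets*; the pad logit above still sits in
# the softmax denominator, so the loss for class 0 stays large:
print(F.cross_entropy(logits, target, ignore_index=1))   # ~3.05

# dropping the pad column from the logits is the "ignore x as well" version:
print(F.cross_entropy(logits[:, :1], target))            # 0.0 (only one class left)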


@DKSG I think ignore_index should only be used in the extreme cases where the ignored target does not exist in your input.
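Read that way, the simplest setup is to give padded positions a target value that never corresponds to a logit column at all, e.g. the default ignore_index of -100 (a hedged sketch; shapes and values are illustrative):

import torch
import torch.nn.functional as F

n_classes = 3
logits = torch.randn(4, n_classes)      # no extra "pad" column in the logits

# padded positions get a target value that is not a valid class index;
# -100 is the default ignore_index of cross_entropy / CrossEntropyLoss
target = torch.tensor([0, 2, -100, -100])

loss = F.cross_entropy(logits, target)  # equivalent to passing ignore_index=-100
print(loss)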


To add onto this topic, it would be very helpful to know exactly how ignore_index works. For instance, consider:

import torch
import torch.nn as nn

N = 10
criterion = nn.CrossEntropyLoss(reduction='none', ignore_index=-1)
groundtruth = torch.rand(N).ge(0.5).type(torch.LongTensor)
groundtruth[7:] = -1
pred = torch.rand(N, 2)
loss = criterion(pred, groundtruth)

In such a situation, PyTorch returns a value of 0.0 for the elements that are to be ignored. Running the same piece of code with N = 5000 returns weird numbers in the loss for the ignored elements, values such as .0646e+24, etc.

Next, how do we backprop on the above loss?
Can I simply find all elements that aren't -1, say loc = groundtruth != -1, and average over only those?

loss = torch.mean(loss[groundtruth!=-1])
loss.backward()

For some weird reason, the above-mentioned approach does not work for me. The code crashes after 10 epochs or so.


This might be a bug, as it seems the values are uninitialized.
I cannot reproduce it using your (modified) code for N = 5000.

Also note that your criterion should get the prediction as the first argument and the target as the second, and reduction should be set to 'none' (lowercase n).


@ptrblck Thanks for your comment, I edited mine accordingly :smile: I can confirm that this bug is occurring in torch 0.4 but cannot reproduce it in 1.0.
Could you comment on the internals of ignore_index?

Is the following code an exact, or near-exact, representation of what ignore_index does?

loss = torch.mean(loss[groundtruth!=-1])
loss.backward()

Your code snippet will give the same result as reduction='mean':

Here is a small example:

import torch
import torch.nn as nn
import torch.nn.functional as F

N = 10
criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=-1)
groundtruth = torch.rand(N).ge(0.5).type(torch.LongTensor)
groundtruth[7:] = -1
pred = torch.rand(N, 2, requires_grad=True)
loss = criterion(pred, groundtruth)
loss.backward()
print(pred.grad)

# Manual approach
pred.grad.zero_()
target = groundtruth[groundtruth != -1]
output = pred[groundtruth != -1]
loss_manual = -1 * F.log_softmax(output, 1).gather(1, target.unsqueeze(1))
loss_manual = loss_manual.mean()
loss_manual.backward()
print(pred.grad)

Hi, I’d like to ask something related to the last answer.

I’m working on a semi-supervised learning project and my dataloader generates batches with labelled (targets with values 0 to N) and unlabelled (-1) samples.

To keep it simple here, I need a CE loss that only computes the loss on the labelled samples within the batch.

Would this manual approach also work?

import torch
import torch.nn as nn
import torch.nn.functional as F

N = 10
groundtruth = torch.rand(N).ge(0.5).type(torch.LongTensor)
groundtruth[7:] = -1
pred = torch.rand(N, 2, requires_grad=True)

# ptrblck's manual approach (no backward has run yet, so there is no pred.grad to zero here)
target = groundtruth[groundtruth!=-1]
output = pred[groundtruth!=-1]
loss_manual = -1 * F.log_softmax(output, 1).gather(1, target.unsqueeze(1))
loss_manual = loss_manual.mean()
loss_manual.backward()
print(pred.grad)

# My manual approach
pred.grad.zero_()
criterion = nn.CrossEntropyLoss(reduction='mean')
target = groundtruth[groundtruth>0]
output = pred[groundtruth>0]
loss_manual = criterion(output, target)
loss_manual.backward()
print(pred.grad)

If it works that would be great, because it's not very clear to me what that gather is doing there…

Thanks!


Your approach currently also filters out the class 0 target indices. If you use groundtruth >= 0, you should get the same results as with my approach.
The gather operation selects the log probabilities at the target indices in dim1.
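If it helps, here is a tiny self-contained demo of what that gather call does (random values, nothing specific to your data):

import torch
import torch.nn.functional as F

output = torch.randn(4, 3)
target = torch.tensor([2, 0, 1, 1])

log_probs = F.log_softmax(output, dim=1)

# for each row, pick the log probability in the column given by target
picked = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)

# the same selection with plain indexing, if that reads more naturally
picked_alt = log_probs[torch.arange(4), target]

print(torch.allclose(picked, picked_alt))   # True
print((-picked).mean())                     # equals F.cross_entropy(output, target)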

So, for people coming from search engines: is it a good idea to leverage ignore_index for padding tokens, especially when they are the majority in a sequence?
