Ignore_index in the cross entropy loss

Hi, I’d like to ask something related to the last answer.

I’m working on a semi-supervised learning project and my dataloader generates batches with labelled (targets with values 0 to N) and unlabelled (-1) samples.

To keep it simple here, I need a CE loss that only computes the loss on the labelled samples within the batch.

Would this manual approach also work?

import torch
import torch.nn as nn
import torch.nn.functional as F

N = 10
groundtruth = torch.rand(N, ).ge(0.5).type(torch.LongTensor)
groundtruth[7:] = -1
pred = torch.rand(N, 2, requires_grad=True)

# ptrblck's manual approach
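# pred.grad is still None before the first backward(), so only zero it if it exists
if pred.grad is not None:
    pred.grad.zero_()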
target = groundtruth[groundtruth!=-1]
output = pred[groundtruth!=-1]
loss_manual = -1 * F.log_softmax(output, 1).gather(1, target.unsqueeze(1))
loss_manual = loss_manual.mean()
loss_manual.backward()
print(pred.grad)

# My manual approach
pred.grad.zero_()
criterion = nn.CrossEntropyLoss(reduction='mean')
# Keep every labelled sample; a `> 0` mask would also drop class 0
target = groundtruth[groundtruth != -1]
output = pred[groundtruth != -1]
loss_manual = criterion(output, target)
loss_manual.backward()
print(pred.grad)
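
For comparison, here is a minimal sketch of the same idea using the built-in `ignore_index` argument of `nn.CrossEntropyLoss`, which is what the thread title refers to. It assumes unlabelled samples are marked with -1 as above; the seed and the variable names (`mask`, `loss_masked`, `loss_ignore`) are only for illustration. With `reduction='mean'` the loss is averaged over the non-ignored targets, so both routes should give the same loss and the same gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # illustrative seed, only for reproducibility

N = 10
groundtruth = torch.randint(0, 2, (N,))   # labels in {0, 1}
groundtruth[7:] = -1                      # last three samples are unlabelled
pred = torch.rand(N, 2, requires_grad=True)

# Route 1: mask out the unlabelled samples by hand
mask = groundtruth != -1
loss_masked = nn.CrossEntropyLoss(reduction='mean')(pred[mask], groundtruth[mask])
grad_masked, = torch.autograd.grad(loss_masked, pred)

# Route 2: let the loss skip targets equal to -1
loss_ignore = nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')(pred, groundtruth)
grad_ignore, = torch.autograd.grad(loss_ignore, pred)

print(torch.allclose(loss_masked, loss_ignore))   # same loss value
print(torch.allclose(grad_masked, grad_ignore))   # same gradients
print(grad_ignore[7:])                            # unlabelled rows get zero gradient
```

One caveat with either route: if a batch happens to contain no labelled samples at all, the mean is taken over zero elements and the loss comes out as NaN, so such batches may need to be skipped.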

If it works, that would be great, because it’s not very clear to me what that gather is doing there…
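
As a rough illustration of the `gather` step: after `log_softmax`, each row holds one log-probability per class, and `gather(1, target.unsqueeze(1))` picks, for every row, the entry at that row's target class. The small tensors below are made-up values, not taken from the snippet above; the sketch only shows that negating and averaging the picked entries matches `F.cross_entropy`.

```python
import torch
import torch.nn.functional as F

# Made-up logits for 3 samples and 2 classes
logits = torch.tensor([[2.0, 0.5],
                       [0.1, 1.2],
                       [0.3, 0.3]])
target = torch.tensor([0, 1, 1])

log_probs = F.log_softmax(logits, dim=1)            # shape (3, 2)

# picked[i, 0] == log_probs[i, target[i]]
picked = log_probs.gather(1, target.unsqueeze(1))   # shape (3, 1)

# Equivalent selection with plain index tensors
same = log_probs[torch.arange(3), target]           # shape (3,)

loss_gather = -picked.mean()
loss_ce = F.cross_entropy(logits, target)

print(torch.allclose(picked.squeeze(1), same))
print(torch.allclose(loss_gather, loss_ce))
```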

Thanks!
