Cross entropy loss ignore index


I am working on a translation project and using torch.nn.CrossEntropyLoss with torch == 1.11.0 (see the CrossEntropyLoss page in the PyTorch 1.11.0 documentation).

Referring to the documentation, I can use logits for the target instead of class indices to get the loss, so the target/label shape will be (batchsize*sentencelength, numberofclass) in my case.

However, the documentation says that I cannot use ignore_index in this case. What should I do, then, if I don’t want to calculate the loss for a given token?

Btw, I am using generated sentences to train the model, so it’s a kind of bi-level optimization with end-to-end training. The generated sentences are one-hot encoded, and I cannot use argmax here because I need to differentiate the final loss with respect to the generator model.

Any help will be appreciated!

NIT: you can use probabilities as the target now, not logits.

If “smooth” targets are passed there is no defined class index anymore, which is why ignore_index won’t work in this case.
Instead, you could use the unreduced loss, filter out the unwanted “classes”, and reduce it afterwards as seen here:

import torch
import torch.nn as nn

# class indices as target; samples with target class 0 are ignored
criterion = nn.CrossEntropyLoss(ignore_index=0)
input = torch.randn(3, 2, requires_grad=True)
target = torch.tensor([0, 1, 0])

loss = criterion(input, target)
# > tensor(1.2963, grad_fn=<NllLossBackward0>)

# class probabilities
criterion = nn.CrossEntropyLoss(reduction='none')
target = torch.nn.functional.one_hot(target).float()

loss = criterion(input, target)
# drop the samples whose target is class 0, then average the rest
loss = loss[target[:, 0] != 1].mean()
# > tensor(1.2963, grad_fn=<MeanBackward0>)
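
For the shapes from the question, where the target is (batchsize*sentencelength, numberofclass), the same idea could look roughly like the sketch below. PAD_IDX, the tensor sizes, and the random logits are assumptions made up for illustration and are not from the original post. Instead of boolean indexing, it weights each position by the probability mass that is not on the padding class, which reduces to the filter above when the targets are exactly one-hot.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

PAD_IDX = 0  # hypothetical index of the class to ignore
batch_size, seq_len, num_classes = 2, 4, 5  # made-up sizes

# Flattened logits: (batch_size * seq_len, num_classes)
logits = torch.randn(batch_size * seq_len, num_classes, requires_grad=True)

# One-hot targets built from made-up indices; in the question they would
# come from the generator instead.
target_idx = torch.tensor([3, 1, 0, 0, 2, 4, 1, 0])
target_probs = F.one_hot(target_idx, num_classes).float()

# Unreduced loss: one value per token position
criterion = nn.CrossEntropyLoss(reduction="none")
per_token_loss = criterion(logits, target_probs)

# Weight is 0 where all target mass sits on PAD_IDX, 1 where none does
keep = 1.0 - target_probs[:, PAD_IDX]
masked_loss = (per_token_loss * keep).sum() / keep.sum().clamp_min(1e-8)

# Reference: class-index targets with ignore_index
reference = F.cross_entropy(logits, target_idx, ignore_index=PAD_IDX)
print(masked_loss.item(), reference.item())
```

With one-hot targets the two printed values should agree. For genuinely soft targets the weighting is only one possible choice, and a separate, explicitly known padding mask may be the safer option.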

Hi ptrblck!
Thank you so much for your answer, it helps a lot!
I have another short question.
I am working on a bi-level optimization problem (i.e. training two models), and I am running into CUDA OOM errors.
Usually, this is how we train the model:

loss = CTG_loss(input_w, input_w_attn, output_w, output_w_attn, attn_idx, A, w_model)

To save CUDA memory, can I do this instead? This way, some memory would be released after training the first model.

loss_w = CTG_loss(input_w, input_w_attn, output_w, output_w_attn, attn_idx, A, w_model)

I don’t think your second approach would avoid the OOM issue; it would only force PyTorch to allocate new memory in the next iteration. To avoid keeping the gradients around, you could use optimizer.zero_grad(set_to_none=True), which could help.
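
A minimal sketch of what set_to_none=True does, assuming a small stand-in nn.Linear model and an SGD optimizer (neither is from the original training code): after the call, each parameter's .grad is None rather than a zero-filled tensor, so the gradient storage can be released until the next backward pass.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in model and optimizer, purely for illustration
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))

loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()

# After backward, every parameter holds a gradient tensor
has_grad_before = [p.grad is not None for p in model.parameters()]

# Drop the gradient tensors instead of overwriting them with zeros
optimizer.zero_grad(set_to_none=True)
grad_is_none_after = [p.grad is None for p in model.parameters()]

print(has_grad_before, grad_is_none_after)
```

In a two-model setup the same call would be made on each optimizer once its step is done. How much memory this actually saves depends on the model sizes, so it may not be enough on its own to resolve an OOM.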
