Weighting cross entropy loss

I am trying to assign different weights to the samples in my batch when computing cross entropy loss. The issue is that these weights are per-sample rather than per-class, so I can't pass them as the weight argument of nn.CrossEntropyLoss() as described at https://pytorch.org/docs/stable/nn.functional.html#cross-entropy. Here is an example of what I am trying to do for each batch:

import torch

criterion = CrossEntropyLoss_modified()  # hypothetical criterion that accepts per-sample weights
fake_outputs = torch.rand((2, 2))  # a single batch of logits, shape (batch_size, num_classes)
labels = torch.randint(0, 2, (2,))  # targets, shape (batch_size,)
weights = torch.tensor([0.7, 0.3])  # one weight per item in the batch; the values change every batch
criterion(fake_outputs, labels, weights)  # should return a weighted loss

The length of the weights tensor will be the batch size, not the number of classes.
Is this doable? Is there a better way to go about this?

I would set the reduction to 'none': nn.CrossEntropyLoss(reduction='none')

This returns a per-sample loss tensor of shape (batch_size,), so you can then do something like this:
(weights * criterion(output, labels)).mean()
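
Putting it together, a minimal runnable sketch (the shapes mirror the example above; in practice the outputs would come from your model rather than torch.rand):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(reduction='none')  # keep per-sample losses

outputs = torch.rand((2, 2), requires_grad=True)   # logits, shape (batch_size, num_classes)
labels = torch.randint(0, 2, (2,))                 # targets, shape (batch_size,)
weights = torch.tensor([0.7, 0.3])                 # per-sample weights, shape (batch_size,)

per_sample_loss = criterion(outputs, labels)       # shape (batch_size,)
loss = (weights * per_sample_loss).mean()          # weighted reduction to a scalar
loss.backward()

If you want a true weighted average rather than a simple mean, dividing the weighted sum by weights.sum() instead of using .mean() is a common alternative.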


Perfect, just what I was looking for. Thanks!