Reduction='none' leads to different computed loss

Hi everyone. I hope my question is not too stupid, as I am a beginner.

I use the cross-entropy loss with 512×512 images and a batch size of 3. I want to compute the reduction myself, so I do:

criterion_none = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='none')
criterion_reduc = torch.nn.CrossEntropyLoss(weight=class_weights)

loss_none = criterion_none(preds, masks)    # without reduction
loss_reduc = criterion_reduc(preds, masks)  # with reduction

The first one gives a 3×512×512 (B, H, W) array of losses and the second one gives a scalar. So far, so good; that's what the reduction does. The documentation says that, by default, this reduction is done by averaging the losses over the mini-batch. So I thought that doing:

loss_none = torch.mean(loss_none)

would give the same result, i.e. loss_none == loss_reduc.

But I have (3 examples):

loss torch.mean(): tensor(2.8491, device='cuda:0', grad_fn=<...>), loss reduction: tensor(28.0402, device='cuda:0', grad_fn=<...>)
loss torch.mean(): tensor(5.1461, device='cuda:0', grad_fn=<...>), loss reduction: tensor(44.5213, device='cuda:0', grad_fn=<...>)
loss torch.mean(): tensor(2.9905, device='cuda:0', grad_fn=<...>), loss reduction: tensor(26.4063, device='cuda:0', grad_fn=<...>)

And why is the grad_fn not the same?

Clearly I do not understand what I am doing.

Thank you very much for any help you could provide. :slight_smile:

The docs specify how the average is calculated if a weight is specified: the weighted losses are summed and divided by the sum of the weights of the targets, not by the number of elements, so a plain torch.mean gives a different value.
This post gives you an example of a manual reduction vs. the internal one.
As for the grad_fn: torch.mean appends its own backward node to the graph, while the built-in reduction is handled inside the loss function's backward, so the two scalars show different grad_fn entries even when their values match.
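To make that concrete, here is a minimal sketch of the weighted reduction for a plain classification case (tensor names and sizes are illustrative, not from the original post):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# toy classification setup: 6 samples, 3 classes
logits = torch.randn(6, 3)               # (N, C)
target = torch.randint(0, 3, (6,))       # (N,)
weights = torch.tensor([0.2, 0.3, 0.5])  # one weight per class

# per-sample losses; each one is already multiplied by weights[target[i]]
loss_none = F.cross_entropy(logits, target, weight=weights, reduction='none')

# with a weight tensor, reduction='mean' divides by the sum of the
# per-sample weights weights[target], NOT by the element count N
manual = loss_none.sum() / weights[target].sum()
internal = F.cross_entropy(logits, target, weight=weights)  # reduction='mean'

print(manual.item(), internal.item())  # the two values match
```

Note that `reduction='none'` already applies the class weights to each per-sample loss; the division by `weights[target].sum()` is the only extra step the internal `'mean'` reduction performs.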


Thank you very much. I just tried with:

criterion_none = torch.nn.CrossEntropyLoss(reduction='none')
criterion = torch.nn.CrossEntropyLoss()

And indeed, it gave me the same loss.

loss torch.mean(): tensor(1.3486, device='cuda:0', grad_fn=<...>), loss reduction: tensor(1.3486, device='cuda:0', grad_fn=<...>)

I read the post with the example you gave for the manual calculation of the loss. I am still trying to adapt this part:

loss_weighted_manual = loss_weighted_manual.sum() / weights[target].sum()

for image segmentation. I guess I’ll figure it out.
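For what it's worth, the same formula seems to carry over to segmentation directly, because indexing the 1-D weight tensor with a (B, H, W) mask broadcasts the class weight to every pixel. A small sketch under that assumption, with toy sizes standing in for the 3×512×512 case:

```python
import torch

torch.manual_seed(0)

# toy segmentation setup: 4 classes, batch of 3, 8x8 "images"
num_classes = 4
preds = torch.randn(3, num_classes, 8, 8)         # (B, C, H, W) logits
masks = torch.randint(0, num_classes, (3, 8, 8))  # (B, H, W) class indices
class_weights = torch.rand(num_classes)

criterion_none = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='none')
criterion_reduc = torch.nn.CrossEntropyLoss(weight=class_weights)

loss_none = criterion_none(preds, masks)  # (B, H, W) per-pixel weighted losses

# class_weights[masks] has shape (B, H, W): one weight per pixel,
# so the manual reduction divides by the total weight over all pixels
loss_manual = loss_none.sum() / class_weights[masks].sum()
loss_internal = criterion_reduc(preds, masks)

print(loss_manual.item(), loss_internal.item())  # the two values match
```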

Thank you again.