How is reduction performed in `F.cross_entropy` when the `weight` parameter is provided?

I'm trying to use the rescaling `weight` in `torch.nn.functional.cross_entropy`, and I find the result very hard to understand. Here is the code:

>>> import torch
>>> import torch.nn.functional as F
>>> pred = torch.tensor([[[0.8054, 0.6918],
         [0.8704, 0.1927],
         [0.4033, 0.3574],
         [0.6289, 0.2227],
         [0.0425, 0.8065]],

        [[0.4279, 0.4677],
         [0.4958, 0.3767],
         [0.3411, 0.9530],
         [0.4712, 0.7330],
         [0.9196, 0.8033]]]).float()  # [2, 5, 2], 5-way classification
>>> label = torch.tensor([[2, 4], [1, 3]]).long()  # [2, 2]
>>> weight = torch.tensor([0.0886, 0.2397, 0.1851, 0.2225, 0.2640]).float()  # weight.sum() == 1
>>> loss1 = F.cross_entropy(pred, label, reduction='mean', weight=weight)
>>> loss1
tensor(1.5594)
>>> loss2 = F.cross_entropy(pred, label, reduction='none', weight=weight).sum() / label.numel()
>>> loss2
tensor(0.3553)
>>> loss3 = F.cross_entropy(pred, label, reduction='sum', weight=weight) / label.numel()
>>> loss3
tensor(0.3553)

If I understand correctly, loss1 should be the same as loss2. Obviously the reduction method is not what I assumed, so I wonder: how does F.cross_entropy perform mean reduction when a weight is provided? Thanks!

nn.CrossEntropyLoss normalizes by the sum of the weights that were actually selected by the targets (i.e. `weight[label].sum()`), not by the number of elements, so you would have to change the loss2 calculation to:

loss2 = F.cross_entropy(pred, label, reduction='none', weight=weight).sum() / weight[label].sum()
loss2
> tensor(1.5594)

This post also describes it using another example.
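For what it's worth, here is a minimal sketch that reproduces the weighted mean reduction from scratch via log_softmax, using the same pred, label, and weight as above, which makes the normalization explicit:

import torch
import torch.nn.functional as F

pred = torch.tensor([[[0.8054, 0.6918],
                      [0.8704, 0.1927],
                      [0.4033, 0.3574],
                      [0.6289, 0.2227],
                      [0.0425, 0.8065]],
                     [[0.4279, 0.4677],
                      [0.4958, 0.3767],
                      [0.3411, 0.9530],
                      [0.4712, 0.7330],
                      [0.9196, 0.8033]]]).float()  # [N=2, C=5, d1=2]
label = torch.tensor([[2, 4], [1, 3]]).long()      # [N=2, d1=2]
weight = torch.tensor([0.0886, 0.2397, 0.1851, 0.2225, 0.2640]).float()

# Per-element loss: -weight[y] * log_softmax(pred)[y]
log_probs = F.log_softmax(pred, dim=1)                       # [2, 5, 2]
picked = log_probs.gather(1, label.unsqueeze(1)).squeeze(1)  # [2, 2]
per_elem = -weight[label] * picked                           # [2, 2]

# 'mean' divides by the sum of the weights picked by the targets,
# not by the number of elements:
manual_mean = per_elem.sum() / weight[label].sum()
builtin_mean = F.cross_entropy(pred, label, reduction='mean', weight=weight)
print(manual_mean, builtin_mean)  # both tensor(1.5594)

Dividing by `weight[label].sum()` rather than `label.numel()` keeps the loss scale independent of which classes happen to appear in a particular batch; when `weight` is `None` (effectively all ones) the two denominators coincide.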

Thank you for your reply! That solves my problem.