I am trying to use the rescaling weight in torch.nn.functional.cross_entropy, and I find the result very hard to understand. Here is the code:
>>> import torch
>>> import torch.nn.functional as F
>>> pred = torch.tensor([[[0.8054, 0.6918],
...                       [0.8704, 0.1927],
...                       [0.4033, 0.3574],
...                       [0.6289, 0.2227],
...                       [0.0425, 0.8065]],
...                      [[0.4279, 0.4677],
...                       [0.4958, 0.3767],
...                       [0.3411, 0.9530],
...                       [0.4712, 0.7330],
...                       [0.9196, 0.8033]]]).float() # [2, 5, 2], 5-way classification
>>> label = torch.tensor([[2, 4], [1, 3]]).long() # [2, 2]
>>> weight = torch.tensor([0.0886, 0.2397, 0.1851, 0.2225, 0.2640]).float() # weight.sum() == 1
>>> loss1 = F.cross_entropy(pred, label, reduction='mean', weight=weight)
>>> loss1
tensor(1.5594)
>>> loss2 = F.cross_entropy(pred, label, reduction='none', weight=weight).sum() / label.numel()
>>> loss2
tensor(0.3553)
>>> loss3 = F.cross_entropy(pred, label, reduction='sum', weight=weight) / label.numel()
>>> loss3
tensor(0.3553)
If I understand correctly, loss1 should be the same as loss2, so the reduction is clearly not working the way I assumed. How does F.cross_entropy perform mean reduction when weight is provided? Thanks!