Hello,
I’m having trouble understanding the behaviour of class weights in CrossEntropyLoss, specifically when reduction='mean'. I test it like this:
import torch
import torch.nn as nn
import torch.nn.functional as F

input = torch.randn(5, 2, requires_grad=True)
m = nn.LogSoftmax(dim=1)
mi = m(input)
target = torch.tensor([0, 0, 1, 1, 1])
w = torch.tensor([1, 100]).float()
Now, without weights everything behaves reasonably; the next two lines give the same result:
F.nll_loss(mi, target, reduction='none').mean()
F.nll_loss(mi, target, reduction='mean')
However, once we introduce the weight, the results differ:
F.nll_loss(mi, target, weight=w, reduction='none').mean()
F.nll_loss(mi, target, weight=w, reduction='mean')
That is extremely unintuitive to me; what is the logic behind this?
The formula in the documentation (https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html) is no help, because it reuses the variable name "n", making it seem as if the weights are the same for all samples. Are they?
I’ll appreciate any help.
KFrank (K. Frank)
August 17, 2020, 5:23pm #2
Hello Awesome!
Passing weights to NLLLoss (and CrossEntropyLoss) gives, with reduction = 'mean', a weighted average: the sum of the weighted values is divided by the sum of the weights.
In your reduction = 'none' version:
F.nll_loss(mi, target, weight=w, reduction='none').mean()
by the time you get to the call to pytorch’s tensor .mean(), the weights are no longer available, so .mean() cannot divide by the sum of the weights.
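If it helps, here is a minimal sketch of my own (not from the linked post) that reproduces the 'mean' reduction by hand with the tensors from your example; the seed is only there for reproducibility. Note also that the w_n in the documentation is the weight of sample n’s target class, i.e. w[target[n]], so it is not the same for every sample.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
input = torch.randn(5, 2, requires_grad=True)
mi = nn.LogSoftmax(dim=1)(input)
target = torch.tensor([0, 0, 1, 1, 1])
w = torch.tensor([1, 100]).float()

# with reduction='none' the per-sample losses are already multiplied by w[target]
per_sample = F.nll_loss(mi, target, weight=w, reduction='none')
sample_weights = w[target]  # the weight actually applied to each sample

print(per_sample.mean())                                   # divides by 5, the number of samples
print(per_sample.sum() / sample_weights.sum())             # divides by the sum of the weights
print(F.nll_loss(mi, target, weight=w, reduction='mean'))  # matches the previous line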
An example script illustrating this can be found in this post:
Hello Mainul!
When using CrossEntropyLoss (weight = sc) with class weights to perform the default reduction = 'mean', the average loss that is calculated is the weighted average. That is, you should be dividing by the sum of the weights used for the samples, rather than by the number of samples.
The following (pytorch version 0.3.0) script illustrates this:
import torch
torch.__version__
sc = torch.FloatTensor ([0.4,0.36])
loss = torch.nn.CrossEntropyLoss (weight = sc)
input = torch.au…
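Since the quoted 0.3.0 script is cut off above, here is a sketch of the same check written against a recent PyTorch version; the input and target tensors are made up for illustration.
import torch
import torch.nn as nn

torch.manual_seed(0)
sc = torch.tensor([0.4, 0.36])             # class weights, as in the quoted post
input = torch.randn(6, 2)                  # 6 samples, 2 classes (made-up data)
target = torch.tensor([0, 1, 1, 0, 1, 1])

per_sample = nn.CrossEntropyLoss(weight=sc, reduction='none')(input, target)

print(nn.CrossEntropyLoss(weight=sc)(input, target))  # default reduction='mean'
print(per_sample.sum() / sc[target].sum())            # same value: divide by the sum of the weights
print(per_sample.mean())                              # different: divides by the number of samples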
Best.
K. Frank
Hi KFrank,
Thank you, I get it now. Seems kinda weird to me, but I also understand the reasoning behind this.