How to use the weight parameter for F.cross_entropy() correctly?

I’m trying to write some code like the following:

import torch
import torch.nn.functional as F
from torch.autograd import Variable

x = Variable(torch.Tensor([[1.0,2.0,3.0]]))
y = Variable(torch.LongTensor([1]))
w = torch.Tensor([1.0,1.0,1.0])
F.cross_entropy(x,y,w)
w = torch.Tensor([1.0,10.0,1.0])
F.cross_entropy(x,y,w)

However, the cross-entropy loss is always 1.4076, whatever w is. What does the weight parameter of F.cross_entropy() actually do, and how do I use it correctly?
I’m using PyTorch 0.3.


You are using it correctly!
However, I think the docs are missing an explanation of how size_average interacts with the weight argument.

Have a look at the docs of NLLLoss. They state that, if reduce=True and size_average=True, the sum of the weighted per-sample losses is divided by the sum of the weights of the corresponding target classes.
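To make that rule concrete, here is a minimal sketch that recomputes the weighted mean by hand and compares it with the built-in result. It assumes a PyTorch release that has the `reduction` argument (not 0.3), and `weighted_ce_mean` is a hypothetical helper written only for illustration, not part of PyTorch.

```python
import torch
import torch.nn.functional as F


def weighted_ce_mean(logits, target, weight):
    """Hypothetical helper: weighted cross-entropy with mean reduction."""
    # Per-sample loss: -log_softmax(logits)[n, target[n]]
    log_probs = F.log_softmax(logits, dim=1)
    nll = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
    # Weight of each sample's target class, w[target[n]]
    w_n = weight[target]
    # Sum of the weighted losses divided by the sum of those same weights
    return (w_n * nll).sum() / w_n.sum()


logits = torch.tensor([[1.0, 2.0, 3.0], [2.0, 1.0, 3.0], [0.5, 0.2, 0.1]])
target = torch.tensor([1, 1, 0])
weight = torch.tensor([1.0, 10.0, 1.0])

manual = weighted_ce_mean(logits, target, weight)
builtin = F.cross_entropy(logits, target, weight=weight, reduction="mean")
print(manual.item(), builtin.item())  # the two values should agree
```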

In your case, since you only have one example, the loss is multiplied by w[1] = 10 and then divided by the same 10, which yields exactly the same result as the unweighted loss:

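import torch
import torch.nn.functional as F

x = torch.Tensor([[1.0,2.0,3.0]])
y = torch.LongTensor([1])
w = torch.Tensor([1.0,1.0,1.0])
F.cross_entropy(x,y,w, size_average=False)
> tensor(1.4076)
w = torch.Tensor([1.0,10.0,1.0])
F.cross_entropy(x,y,w, size_average=False)
> tensor(14.0761)
F.cross_entropy(x,y,w, size_average=True)
> tensor(1.4076)

# The same numbers by hand (w[y] = 10 here):
loss = w[y] * (-x[0, y] + torch.log(torch.sum(torch.exp(x))))  # weighted loss, 14.0761
averaged_loss = loss / w[y]                                      # divided by w[y] again, 1.4076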
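A quick way to convince yourself that the weights cancel for a single sample is to loop over a few weight vectors. This is a sketch assuming a recent PyTorch with the `reduction` argument (`reduction="sum"` and `reduction="mean"` play the roles of `size_average=False` and `size_average=True`); the third weight vector is arbitrary.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[1.0, 2.0, 3.0]])
y = torch.tensor([1])
unweighted = F.cross_entropy(x, y).item()

weights = [
    torch.tensor([1.0, 1.0, 1.0]),
    torch.tensor([1.0, 10.0, 1.0]),
    torch.tensor([5.0, 0.5, 2.0]),  # arbitrary example weights
]

results = []
for w in weights:
    # 'sum' scales the single loss by w[y]; 'mean' then divides by w[y] again
    summed = F.cross_entropy(x, y, weight=w, reduction="sum").item()
    mean = F.cross_entropy(x, y, weight=w, reduction="mean").item()
    results.append((summed, mean))
    print(w.tolist(), round(summed, 4), round(mean, 4))
```

The summed value changes with w[1], while the mean stays at the unweighted loss for every weight vector.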

If you add a second sample, you can see the division by the summed weights at work:

x = torch.Tensor([[1.0,2.0,3.0], [2.0, 1.0, 3.0]])
y = torch.LongTensor([1, 1])
w = torch.Tensor([1.0,1.0,1.0])
F.cross_entropy(x,y,w, size_average=False)
> tensor(3.8152)
w = torch.Tensor([1.0,10.0,1.0])
F.cross_entropy(x,y,w, size_average=False)
> tensor(38.1521)
F.cross_entropy(x,y,w, size_average=True)
> tensor(1.9076)

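In this example, 38.1521 is divided by the sum of the corresponding weights (w[1] = 10 for each sample), i.e. 38.1521 / 20 = 1.9076. Note that because both targets are class 1, this is still the same as the unweighted mean (3.8152 / 2 = 1.9076); the weights only change the mean loss when the batch contains targets of different classes.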
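For the weights to move the mean, the two samples need different target classes. A minimal sketch, assuming a PyTorch version with the `reduction` argument; the targets `[1, 2]` are chosen only for illustration:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[1.0, 2.0, 3.0], [2.0, 1.0, 3.0]])
y = torch.tensor([1, 2])  # the two samples now belong to different classes
w = torch.tensor([1.0, 10.0, 1.0])

unweighted = F.cross_entropy(x, y, reduction="mean")
weighted = F.cross_entropy(x, y, weight=w, reduction="mean")
print(round(unweighted.item(), 4), round(weighted.item(), 4))
```

Here the class-1 sample has the larger loss (about 1.4076 vs 0.4076), so its weight of 10 pulls the mean from about 0.9076 up to about (10 * 1.4076 + 0.4076) / 11 ≈ 1.3167.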


Thanks a lot, it totally cleared up my confusion!

Thanks for the help.

Hi,
the size_average argument is deprecated. Use the reduction argument instead:
print(F.cross_entropy(x, y, w, reduction='sum'))   # instead of size_average=False
print(F.cross_entropy(x, y, w, reduction='mean'))  # instead of size_average=True
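A third option, `reduction="none"`, returns the per-sample losses and makes the relationship between the two reductions explicit. A small sketch, assuming a current PyTorch release; the example tensors are made up for illustration:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[1.0, 2.0, 3.0], [2.0, 1.0, 3.0]])
y = torch.tensor([1, 2])
w = torch.tensor([1.0, 10.0, 1.0])

# Per-sample losses, each already multiplied by w[y[n]]
per_sample = F.cross_entropy(x, y, weight=w, reduction="none")
total = F.cross_entropy(x, y, weight=w, reduction="sum")   # replaces size_average=False
mean = F.cross_entropy(x, y, weight=w, reduction="mean")   # replaces size_average=True

# 'sum' adds up the weighted per-sample losses;
# 'mean' divides that sum by the summed target-class weights, not by the batch size
print(per_sample, total, mean)
print(total / w[y].sum(), total / len(y))
```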


Thanks for your help!!!