CrossEntropyLoss with weights and size_average implementation details

Hey, I’m trying to reproduce the CrossEntropyLoss implementation (in order to change it later for my needs), and currently I can’t match the results when non-uniform weights are provided and size_average is set to True. If the weights are uniform and/or size_average is False, the results match, at least in their printed representations.

I tried to follow the formula in the PyTorch reference, but it seems that either I’m missing something or the weights are applied slightly differently (or maybe I have a bug, of course :slight_smile:).

Here’s my implementation:

import torch

num_classes = 5
num_samples = 3

wts = torch.abs(torch.randn(num_classes))
wts /= torch.sum(wts)

weights = torch.autograd.Variable(wts)
# weights = torch.autograd.Variable(torch.ones(num_classes))

input = torch.autograd.Variable(torch.randn(num_samples, num_classes))
target = torch.autograd.Variable(torch.LongTensor(num_samples).random_(num_classes))

# todo: check why size_average changes weight contribution
loss = torch.nn.CrossEntropyLoss(weight=weights, size_average=False)
output = loss(input, target)

correct_confidences = torch.exp(input[range(num_samples), target])
total_confidences = torch.sum(torch.exp(input), dim=1)
p_t = correct_confidences/total_confidences
CE = torch.sum(weights.index_select(0, target)*(-torch.log(p_t)))

print 'torch CE =', output
print 'manual CE =', CE

Example of output:

torch CE = Variable containing:
[torch.FloatTensor of size 1]

manual CE = Variable containing:
[torch.FloatTensor of size 1]

If I change size_average to True and replace torch.sum with torch.mean in the CE computation, here’s an example of what I get:

torch CE = Variable containing:
[torch.FloatTensor of size 1]

manual CE = Variable containing:
[torch.FloatTensor of size 1]

I know that dividing by total_confidences is not the best idea, but I guess it’s not the main issue here. Also since I’m a newbie at pytorch some places might look odd - feel free to point those out.
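
On the side note about dividing by total_confidences: one common way around it is the log-sum-exp trick, which subtracts the largest logit before exponentiating so that exp() cannot overflow. Below is a minimal plain-Python sketch of that idea, without any PyTorch calls. The function names are made up for illustration and are not part of any library.

```python
import math


def naive_neg_log_prob(logits, target):
    """-log(softmax(logits)[target]) computed by dividing exponentials.

    Mirrors the correct_confidences / total_confidences approach;
    math.exp overflows once a logit gets large (roughly above 709).
    """
    exps = [math.exp(x) for x in logits]
    return -math.log(exps[target] / sum(exps))


def stable_neg_log_prob(logits, target):
    """Same quantity, computed as logsumexp(logits) - logits[target].

    Subtracting the maximum logit first keeps every exponent <= 0,
    so exp() stays within [0, 1].
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]
```

For small logits both functions should agree up to rounding. For something like [1000.0, 999.0, 0.0] the naive version raises OverflowError, while the stable one still returns a finite value.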

I think for size_average=True you have to sum the weighted loss and divide by the sum of the weights that were actually used, i.e. the weights of the target classes in the batch:

CE = torch.sum(weights.index_select(0, target)*(-torch.log(p_t))) / torch.sum(weights.index_select(0, target))

This should yield the same result.
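
To make the two reductions concrete, here is a small framework-free sketch of the arithmetic described above. It is plain Python rather than PyTorch, the function name is hypothetical, and the size_average parameter only mirrors the flag discussed in this thread. It is meant as an illustration of the formula, not as the library's actual implementation.

```python
import math


def weighted_cross_entropy(logits, targets, class_weights, size_average=True):
    """Weighted cross entropy over a batch of logit rows.

    size_average=False: sum_i w[t_i] * (-log p_i[t_i])
    size_average=True : the same sum divided by sum_i w[t_i],
                        i.e. by the weights of the targets in this batch,
                        not by the number of samples.
    """
    total_loss = 0.0
    total_weight = 0.0
    for row, t in zip(logits, targets):
        # -log softmax(row)[t], via log-sum-exp for numerical stability
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        neg_log_p = log_sum_exp - row[t]

        w = class_weights[t]
        total_loss += w * neg_log_p
        total_weight += w
    return total_loss / total_weight if size_average else total_loss
```

With uniform weights of 1.0 the denominator equals the number of samples, which would explain why the torch.mean version only matches in the uniform case.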

Thank you, worked like a charm!
So it means that the “total weight” per batch is normalized to 1, while the per-class weights themselves don’t necessarily have to sum to 1. Makes sense I guess.

Some time ago I stumbled upon a similar issue and that’s how I understand it. :wink:
Yeah, the class weights do not have to sum to 1.
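
A quick way to convince yourself of this: because the averaged loss divides by the sum of the used weights, multiplying all class weights by a constant cancels out, whereas the summed loss scales with that constant. The sketch below uses made-up per-sample loss values and hypothetical helper names purely for illustration.

```python
def weighted_sum(losses, weights):
    """Sum of per-sample losses, each scaled by its class weight."""
    return sum(w * l for w, l in zip(weights, losses))


def weighted_mean(losses, weights):
    """Weighted sum divided by the total weight used in the batch."""
    return weighted_sum(losses, weights) / sum(weights)


# Made-up values of -log(p_t) per sample, and the class weight
# selected for each sample's target.
losses = [0.5, 1.2, 2.0]
weights = [0.2, 0.3, 0.5]
scaled = [10.0 * w for w in weights]

# The two means agree up to rounding; the second sum is 10x the first.
print(weighted_mean(losses, weights), weighted_mean(losses, scaled))
print(weighted_sum(losses, weights), weighted_sum(losses, scaled))
```

So only the ratios between class weights matter when the loss is averaged this way.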