Passing the weights to CrossEntropyLoss correctly

How are we getting the weights? For example, if the true class distribution at test time is 10:1 (by sample frequency) and in training it is 1:1, what should the weights be?


You shouldn’t have such variations in your distributions between train and test anyhow.

Hi,
I noticed that the KLDivLoss class does not have a weight parameter, so how can I pass weights to it?
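One possible workaround, assuming a per-class weight is what you are after: compute the element-wise KL terms with reduction='none' and apply the weights yourself. This is only a sketch. The names weighted_kl_div and class_weights are made up for this example and are not part of the PyTorch API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def weighted_kl_div(log_probs, target_probs, class_weights):
    # Hypothetical helper: KLDivLoss has no weight argument, so weight manually.
    # reduction='none' returns target * (log(target) - log_probs), shape (N, C).
    pointwise = nn.KLDivLoss(reduction='none')(log_probs, target_probs)
    # Scale each class column, sum over classes, then average over the batch.
    return (pointwise * class_weights).sum(dim=1).mean()

torch.manual_seed(0)
log_probs = F.log_softmax(torch.randn(4, 3), dim=1)  # input must be log-probabilities
target_probs = F.softmax(torch.randn(4, 3), dim=1)   # target is a probability distribution
class_weights = torch.tensor([1., 2., 3.])           # illustrative values only

loss = weighted_kl_div(log_probs, target_probs, class_weights)
print(loss)
```

With all weights set to 1 this should reduce to nn.KLDivLoss(reduction='batchmean'), which is a quick sanity check before plugging in real weights.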

Hi, shouldn’t your weights/class_weights vector be normalised (i.e. the total adds up to 1) before passing it to the CrossEntropyLoss function?

I think we don’t need to normalize the class weights; using the inverse class sample count is okay.
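A quick way to convince yourself, as a sketch with made-up weights: with the default reduction='mean', the loss is divided by the sum of the weights of the targets in the batch, so rescaling all weights by a constant cancels out. The tensor names below are only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(10, 5)
target = torch.randint(0, 5, (10,))

raw_weights = torch.tensor([1., 2., 3., 4., 5.])
normalized_weights = raw_weights / raw_weights.sum()  # sums to 1

loss_raw = nn.CrossEntropyLoss(weight=raw_weights)(logits, target)
loss_normalized = nn.CrossEntropyLoss(weight=normalized_weights)(logits, target)

# Compare with a tolerance, since the two paths differ by floating-point rounding
print(torch.allclose(loss_raw, loss_normalized))
```

This only holds for reduction='mean'. With reduction='sum' or reduction='none' the scale of the weights directly scales the loss.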


Small addition to @MariosOreo’s answer:
if your loss function uses reduction='mean', the loss will be normalized by the sum of the corresponding weights for each element. If you are using reduction='none', you would have to take care of the normalization yourself.
Here is a small example:

import torch
import torch.nn as nn

x = torch.randn(10, 5)
target = torch.randint(0, 5, (10,))

weights = torch.tensor([1., 2., 3., 4., 5.])
criterion_weighted = nn.CrossEntropyLoss(weight=weights)
loss_weighted = criterion_weighted(x, target)

criterion_weighted_manual = nn.CrossEntropyLoss(weight=weights, reduction='none')
loss_weighted_manual = criterion_weighted_manual(x, target)
# normalize by the sum of the weights that were applied to each target
loss_weighted_manual = loss_weighted_manual.sum() / weights[target].sum()

# compare with a tolerance rather than ==, to allow for floating-point rounding
print(torch.allclose(loss_weighted, loss_weighted_manual))

This is exactly related to my problem. I am working in a continual learning setting where the number of classes changes over time. If I want to use weight balancing but the weights vector changes over time, what do I do? Do I have to reinitialize my loss for every new task to use weight balancing?
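One option, sketched under the assumption that inverse class counts are used as weights: call the functional form F.cross_entropy and pass the current weight tensor on every call, so there is no module to reinitialize. The helper inverse_count_weights and the class counts are hypothetical.

```python
import torch
import torch.nn.functional as F

def inverse_count_weights(class_counts):
    # Hypothetical helper: inverse-frequency weights for the classes seen so far.
    counts = torch.as_tensor(class_counts, dtype=torch.float32)
    return 1.0 / counts.clamp(min=1.0)  # clamp avoids division by zero

torch.manual_seed(0)

# Task 1: three classes
logits = torch.randn(8, 3)
target = torch.randint(0, 3, (8,))
loss_task1 = F.cross_entropy(logits, target, weight=inverse_count_weights([50, 30, 20]))

# Task 2: two new classes, so the weight vector grows with the logits
logits = torch.randn(8, 5)
target = torch.randint(0, 5, (8,))
loss_task2 = F.cross_entropy(logits, target, weight=inverse_count_weights([50, 30, 20, 10, 5]))

print(loss_task1, loss_task2)
```

Creating a new nn.CrossEntropyLoss(weight=...) per task should work just as well, since the module holds no trainable parameters.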


@ptrblck will gradients be backpropagated through these weights if we construct the weights from the logits themselves, as in focal losses for object detection (https://arxiv.org/pdf/1708.02002.pdf)?

No, as the derivative for weights is not implemented.
You could check it by passing a weights tensor using requires_grad=True, which will raise the exception:

RuntimeError: the derivative for 'weight' is not implemented

However, have a look at this implementation of the focal loss.

I see, thanks.
Alternatively, even if I construct this weights matrix and multiply it element-wise with the loss matrix I get back when using the loss with reduction='none', it should be equivalent, right?

I think so, but I would need to see the code to be sure.
However, I’m not sure how you are constructing the weights or whether they should also be trainable.

In my case the final focal loss computation looks like the code below. As I understand it, focal loss is supposed to backpropagate the gradients through the weights as well, since none of the repos I referenced, including the one mentioned above, call detach() on these weights, for which backward() is well defined:

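    # classification holds raw logits; targets, targets_eq_1, alpha and gamma are defined elsewhere
    probs = torch.sigmoid(classification)  # the focal weight needs probabilities, not logits
    alpha_factor = torch.ones_like(targets) * alpha
    alpha_factor = torch.where(targets_eq_1, 1. - alpha_factor, alpha_factor)
    focal_weight = torch.where(targets_eq_1, 1. - probs, probs)
    focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
    # size_average is deprecated and would override reduction='none', so it is dropped
    cls_loss = F.binary_cross_entropy_with_logits(classification, targets, reduction='none')
    final_loss = focal_weight * cls_loss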

You might be right, but wouldn’t you recreate focal_weight in each iteration (based on the targets)?
Where would the gradient in focal_weight be used in this case?

Yes, this focal weight would change every iteration. I myself wasn’t entirely sure if focal weights should be backpropagated into, but it seems that way from open-source implementations (the original paper doesn’t have a clarification https://arxiv.org/abs/1708.02002)
In this case, the gradients would flow from the focal weights into/through the ‘classification’ vector (which are the logits) I guess.
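To see what difference this makes, here is a small sketch that compares the gradient on the logits with and without detaching the focal weight. It is not taken from any of the repos mentioned. The function binary_focal_loss and its detach_weight flag are made up for this example, and it assumes the paper's convention of applying alpha to the positive class.

```python
import torch
import torch.nn.functional as F

def binary_focal_loss(logits, targets, alpha=0.25, gamma=2.0, detach_weight=False):
    # Hypothetical helper, written only to compare the two gradient paths.
    probs = torch.sigmoid(logits)
    # p_t is the predicted probability of the true class
    p_t = torch.where(targets == 1, probs, 1.0 - probs)
    alpha_t = torch.where(targets == 1,
                          torch.full_like(probs, alpha),
                          torch.full_like(probs, 1.0 - alpha))
    focal_weight = alpha_t * (1.0 - p_t) ** gamma
    if detach_weight:
        # Treat the weight as a constant: no gradient flows through it
        focal_weight = focal_weight.detach()
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    return (focal_weight * bce).sum()

torch.manual_seed(0)
logits = torch.randn(4, 3, requires_grad=True)
targets = torch.randint(0, 2, (4, 3)).float()

grad_full, = torch.autograd.grad(binary_focal_loss(logits, targets), logits)
grad_detached, = torch.autograd.grad(
    binary_focal_loss(logits, targets, detach_weight=True), logits)

# The two gradients differ, because the weight itself depends on the logits
print(torch.allclose(grad_full, grad_detached))
```

So without detach() the extra term from the focal weight does reach the logits, which matches what the open-source implementations appear to do.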

I’m missing something. Why is the loss for an error on the rare class the same as the loss for an error on the common class?

weight_inbalance = torch.tensor([1000.0, 1.0])  # very rare first class
loss_inbalance = torch.nn.CrossEntropyLoss(weight=weight_inbalance, reduction='mean')

# batch of 2 samples, toggle only the first sample
inbal_err_on_rare = loss_inbalance(torch.tensor([[-1., -1], [-1, 1]]), torch.tensor([0, 1]))
inbal_err_on_com = loss_inbalance(torch.tensor([[1., 1], [-1, 1]]), torch.tensor([0, 1]))
print(inbal_err_on_rare, inbal_err_on_com)
> tensor(0.6926) tensor(0.6926)

The same experiment with balanced classes yields the same equivalence, just offset in value.

weight_balance = torch.tensor([1.0, 1.0])
loss_balance = torch.nn.CrossEntropyLoss(weight=weight_balance, reduction='mean')
bal_err_on_not_rare = loss_balance(torch.tensor([[-1., -1], [-1, 1]]), torch.tensor([0, 1]))
bal_err_on_com = loss_balance(torch.tensor([[1., 1], [-1, 1]]), torch.tensor([0, 1]))
print(bal_err_on_not_rare, bal_err_on_com)
> tensor(0.4100) tensor(0.4100)

I generated the contours, and they look the same, just shifted.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

weight_inbalance = torch.tensor([1000.0, 1.0])  # very rare first class
weight_balance = torch.tensor([1.0, 1.0])
loss_inbalance = torch.nn.CrossEntropyLoss(weight=weight_inbalance, reduction='mean')
loss_balance = torch.nn.CrossEntropyLoss(weight=weight_balance, reduction='mean')
X = np.arange(-1., 1.1, 0.1)
Y = np.arange(-1., 1.1, 0.1)
Z_inbalance = np.zeros((X.shape[0], Y.shape[0]))
Z_balance = np.zeros((X.shape[0], Y.shape[0]))

target = torch.tensor([0, 1])
for i, x in enumerate(X):
    for j, y in enumerate(Y):
        # only the first sample's logits vary; the second sample stays fixed
        batch = torch.tensor([[x, y], [-1, 1]], dtype=torch.float32)
        Z_inbalance[i, j] = loss_inbalance(batch, target).item()
        Z_balance[i, j] = loss_balance(batch, target).item()

fig, ax = plt.subplots(2)

CS0 = ax[0].contour(Y, X, Z_inbalance, cmap=cm.coolwarm)
ax[0].clabel(CS0, inline=1, fontsize=14)
#ax[0].set_title('Inbalanced Classes')
ax[0].set_ylabel('true class is rare')
#ax[0].set_xlabel("not selected class is Common")

CS1 = ax[1].contour(Y, X, Z_balance, cmap=cm.coolwarm)
ax[1].clabel(CS1, inline=1, fontsize=14)
#ax[1].set_title('Balanced Classes')
ax[1].set_ylabel('true class is common')
ax[1].set_xlabel("not selected class is Common")

So either I am missing the point of weights, or I made a newbie error. Either way, I would appreciate some comments.

[image: contour plots of the loss, imbalanced weights (top) and balanced weights (bottom)]

Note that in your first example both output tensors predict class 1 only, so it’s expected to get the same loss value.
You can check it via:

torch.argmax(torch.tensor([[-1.,-1],[-1,1]]), 1)
> tensor([1, 1])

torch.argmax(torch.tensor([[1.,1],[-1,1]]), 1)
> tensor([1, 1])

I haven’t looked into the second experiment as the formatting is unclear, but you could rerun it with another reduction (e.g. reduction='sum'), since the default reduction='mean' would normalize the loss with the applied weights.
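To make the effect of the reduction concrete, here is a small sketch based on the tensors from your first example. It first shows that the two "toggled" samples have the same unweighted loss, since equal logits give a probability of 0.5 for each class regardless of their value. It then compares the three reductions. The variable names are only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

target = torch.tensor([0, 1])
weight = torch.tensor([1000., 1.])

# Equal logits give probability 0.5 per class, so both samples cost ln(2)
loss_a = F.cross_entropy(torch.tensor([[-1., -1.]]), torch.tensor([0]))
loss_b = F.cross_entropy(torch.tensor([[1., 1.]]), torch.tensor([0]))
print(loss_a, loss_b)

logits = torch.tensor([[-1., -1.], [-1., 1.]])
per_sample = nn.CrossEntropyLoss(weight=weight, reduction='none')(logits, target)
loss_sum = nn.CrossEntropyLoss(weight=weight, reduction='sum')(logits, target)
loss_mean = nn.CrossEntropyLoss(weight=weight, reduction='mean')(logits, target)

# 'mean' divides the weighted sum by the weights of the targets (1000 + 1 here)
manual_mean = per_sample.sum() / weight[target].sum()
print(per_sample, loss_sum, loss_mean, manual_mean)
```

With reduction='none' or reduction='sum' the factor of 1000 is visible directly, while reduction='mean' turns it into a weighted average, so the value stays on the scale of a single per-sample loss.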

Thank you, again.
I’m not following how the argmax relates to CrossEntropyLoss. Please elaborate.

You are correct, somehow the “mean” nullifies the weight bias.

Experiment 1 shows that, when classes are balanced, an error on any single logit results in the same loss.
Experiment 2 shows that, if the rare class is the true class, the loss is amplified whenever an error occurs.

Experiment 1:

  • balanced 2 classes,
  • batch of 2 samples,
  • introduce a single logit error and observe loss for true class = 0, 0 and 1, 1

loss = torch.nn.CrossEntropyLoss(weight=torch.tensor([1., 1.]), reduction='sum')
print("2 balanced class, reduction=sum")

true_class_index = torch.tensor([0,0]) # so perfect prediction is [1,-1], [1,-1]
incorrect_sample_0_class_0 = loss(torch.tensor([[-1., -1],[1.,-1]]), true_class_index)
incorrect_sample_0_class_1 = loss(torch.tensor([[1., 1],[1.,-1]]), true_class_index)
incorrect_sample_1_class_0 = loss(torch.tensor([[1., -1],[-1.,-1]]), true_class_index)
incorrect_sample_1_class_1 = loss(torch.tensor([[1., -1],[1.,1]]), true_class_index)
print(f"all have one logit wrong {incorrect_sample_0_class_0:0.2f}, {incorrect_sample_0_class_1:0.2f}, {incorrect_sample_1_class_0:0.2f}, {incorrect_sample_1_class_1:0.2f}")

> all have one logit wrong 0.82, 0.82, 0.82, 0.82

true_class_index = torch.tensor([1,1]) # so perfect prediction is [-1,1], [-1,1]
incorrect_sample_0_class_0 = loss(torch.tensor([[1., 1],[-1.,1]]), true_class_index)
incorrect_sample_0_class_1 = loss(torch.tensor([[-1., -1],[-1.,1]]), true_class_index)
incorrect_sample_1_class_0 = loss(torch.tensor([[-1., 1],[1.,1]]), true_class_index)
incorrect_sample_1_class_1 = loss(torch.tensor([[-1., 1],[-1.,-1]]), true_class_index)
print(f"all have one logit wrong {incorrect_sample_0_class_0:0.2f}, {incorrect_sample_0_class_1:0.2f}, {incorrect_sample_1_class_0:0.2f}, {incorrect_sample_1_class_1:0.2f}")

> all have one logit wrong 0.82, 0.82, 0.82, 0.82

Experiment 2:

  • unbalanced 2 classes,
  • batch of 2 samples,
  • introduce a single logit error and observe loss for true class = 0, 0 and 1, 1

loss = torch.nn.CrossEntropyLoss(weight=torch.tensor([10000., 1.]), reduction='sum')
print("2 unbalanced class, reduction=sum")

true_class_index = torch.tensor([0,0]) # so perfect prediction is [1,-1], [1,-1]
incorrect_sample_0_class_0 = loss(torch.tensor([[-1., -1],[1.,-1]]), true_class_index)
incorrect_sample_0_class_1 = loss(torch.tensor([[1., 1],[1.,-1]]), true_class_index)
incorrect_sample_1_class_0 = loss(torch.tensor([[1., -1],[-1.,-1]]), true_class_index)
incorrect_sample_1_class_1 = loss(torch.tensor([[1., -1],[1.,1]]), true_class_index)
print(f"all have one logit wrong {incorrect_sample_0_class_0:0.2f}, {incorrect_sample_0_class_1:0.2f}, {incorrect_sample_1_class_0:0.2f}, {incorrect_sample_1_class_1:0.2f}")

> all have one logit wrong 8200.75, 8200.75, 8200.75, 8200.75

true_class_index = torch.tensor([1,1]) # so perfect prediction is [-1,1], [-1,1]
incorrect_sample_0_class_0 = loss(torch.tensor([[1., 1],[-1.,1]]), true_class_index)
incorrect_sample_0_class_1 = loss(torch.tensor([[-1., -1],[-1.,1]]), true_class_index)
incorrect_sample_1_class_0 = loss(torch.tensor([[-1., 1],[1.,1]]), true_class_index)
incorrect_sample_1_class_1 = loss(torch.tensor([[-1., 1],[-1.,-1]]), true_class_index)
print(f"all have one logit wrong {incorrect_sample_0_class_0:0.2f}, {incorrect_sample_0_class_1:0.2f}, {incorrect_sample_1_class_0:0.2f}, {incorrect_sample_1_class_1:0.2f}")

> all have one logit wrong 0.82, 0.82, 0.82, 0.82

By the way, is there a way to post code snippets on this site without losing the indentation?

An excellent summary for CrossEntropy https://youtu.be/ErfnhcEV1O8

Is it generally good practice to make the weights linearly related to the number of labelled pixels per class? I.e., if you have 10 of class 1, 10 of class 2, and 20 of class 3, would your weights be [1,1,2]? I am facing a segmentation problem where there are many orders of magnitude of difference between the classes, and I am not sure what loss function or weights to use to handle this.


Usually you increase the weight for minority classes, so that their loss also increases and forces the model to learn these samples. This could be done by e.g. the inverse class count (class frequency).
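As a minimal sketch of the inverse class count approach, using the 10/10/20 counts from the question above as made-up labels: count the samples (or pixels) per class with torch.bincount and invert the counts. This assumes every class appears at least once, since a zero count would produce an infinite weight.

```python
import torch
import torch.nn as nn

# Made-up labels: 10 samples of class 0, 10 of class 1, 20 of class 2
labels = torch.tensor([0] * 10 + [1] * 10 + [2] * 20)
num_classes = 3

# Per-class counts; for segmentation masks, flatten the mask first
counts = torch.bincount(labels.flatten(), minlength=num_classes).float()

# Inverse class count: the rarer the class, the larger its weight
weights = 1.0 / counts
print(weights)

criterion = nn.CrossEntropyLoss(weight=weights)

torch.manual_seed(0)
logits = torch.randn(labels.numel(), num_classes)
loss = criterion(logits, labels)
print(loss)
```

Note that this gives the larger class the smaller weight, i.e. roughly [2, 2, 1] after rescaling rather than [1, 1, 2]. For classes that differ by many orders of magnitude, the raw inverse counts can get extreme, so it may be worth experimenting with a softer scheme.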
