Weighted loss for unbalanced datasets

I am trying to use loss weighting for my unbalanced dataset. While searching for how to pass the weight factors to a loss function, I came upon this discrepancy:
Doc for NLLLoss:

  • weight ( Tensor , optional ) – a manual rescaling weight given to each class. If given, it has to be a Tensor of size C. Otherwise, it is treated as if having all ones.

Doc for BCEWithLogitsLoss:

  • weight ( Tensor , optional ) – a manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size “nbatch”.

I am using BCEWithLogitsLoss in my model, and I tried providing the weight as a one-dimensional tensor where each entry represents the weight factor for the corresponding batch sample. With a batch size of 4, this gave me an error of the form:

The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 1

Note that the tensor of weight factors for this 2-class problem was of the following form:

tensor([0.2083, 0.7917, 0.7917, 0.2083])
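For reference, here is a minimal sketch (with made-up logits and one-hot targets) that reproduces the mismatch: the element-wise loss has shape (4, 2), so a flat (4,) weight fails to broadcast at dimension 1, while a (4, 1) weight broadcasts across the class dimension:

```python
import torch
import torch.nn as nn

# Hypothetical setup: batch of 4 samples, 2 output units, one-hot targets.
logits = torch.randn(4, 2)
targets = torch.tensor([[1., 0.], [0., 1.], [0., 1.], [1., 0.]])

# A flat (4,) per-sample weight cannot broadcast against the element-wise
# (4, 2) loss: dimension 1 is 2 vs. 4, which raises the error quoted above.
w = torch.tensor([0.2083, 0.7917, 0.7917, 0.2083])
# nn.BCEWithLogitsLoss(weight=w)(logits, targets)  # RuntimeError

# Reshaped to (4, 1), the weight broadcasts across the class dimension.
loss = nn.BCEWithLogitsLoss(weight=w.unsqueeze(1))(logits, targets)
```

Note also that for class rebalancing specifically, BCEWithLogitsLoss accepts a separate pos_weight argument (a per-output weight for the positive class), which may be a more direct fit for unbalanced data.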

Then I tried providing the weight factor as specified in the NLLLoss docs, where each entry in the tensor is the weight factor for a specific class. If I understand correctly, in this case the labels need to be integers in the range [0, 1, …, C-1], where C is the number of classes, and the class weight factors are then mapped to the labels by position: if the class is 2, the weight factor is weight_tensor[2] (the weight with index 2 in the weight tensor). When I implemented it like this, the code worked fine. So I was thinking maybe the docs for BCEWithLogitsLoss need to be changed, since it seems the latter implementation works correctly.
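For illustration, a minimal sketch of this working variant (the log-probabilities and labels here are made up):

```python
import torch
import torch.nn as nn

# One weight per class, indexed by class label: weight[c] rescales the
# loss of every sample whose label is c.
class_weights = torch.tensor([0.2083, 0.7917])

# NLLLoss expects log-probabilities and integer class labels in [0, C-1].
log_probs = torch.log_softmax(torch.randn(4, 2), dim=1)
labels = torch.tensor([1, 0, 0, 1])

loss = nn.NLLLoss(weight=class_weights)(log_probs, labels)
```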

If this is not the case, I don’t understand in exactly what form the class weights need to be passed to BCEWithLogitsLoss.

Thank you!

How are you creating the weight tensor?
Dynamically, based on the current batch, or did you precompute it somehow?
Could you post a small code snippet showing this error?

PS: It might be related to this issue.

Yes, the weights are precomputed based on class frequencies. And yes, the issue you pointed out is the same as mine. Your discussion in that issue also clarifies for me what is going on, so thank you very much for the help!
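For completeness, a minimal sketch of one common way to precompute such weights from class frequencies (the label tensor here is made up):

```python
import torch

# Hypothetical integer labels for the whole training set.
labels = torch.tensor([0, 1, 1, 0, 1, 1, 1, 0])

# Inverse class frequencies, normalized to sum to 1, so that rarer
# classes receive larger weights.
counts = torch.bincount(labels).float()
class_weights = (1.0 / counts) / (1.0 / counts).sum()
```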