Hi, I was looking for a weighted BCE loss function in PyTorch but couldn't find one. If such a function exists, I would appreciate it if someone could provide its name.
nn.BCEWithLogitsLoss takes a weight and a pos_weight argument.
From the docs:
weight (Tensor, optional) – a manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size nbatch.
pos_weight (Tensor, optional) – a weight of positive examples. Must be a vector with length equal to the number of classes.
Is one of these weights what you are looking for?
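For reference, here is a minimal sketch of passing both arguments to nn.BCEWithLogitsLoss. The batch size, number of classes and all weight values are made up for illustration. Giving weight the shape (batch_size, 1) is an assumption that relies on broadcasting so that each row's factor applies to all of its classes. Because the module stores weight at construction time, per-batch weights may be easier to pass through torch.nn.functional.binary_cross_entropy_with_logits, which accepts the same weight and pos_weight arguments.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, num_classes = 4, 3

logits = torch.randn(batch_size, num_classes)                      # raw scores, no sigmoid applied
targets = torch.randint(0, 2, (batch_size, num_classes)).float()   # 0/1 targets as floats

# pos_weight: one factor per class, applied only to the target == 1 terms
pos_weight = torch.tensor([1.0, 2.0, 5.0])
# weight: one factor per batch element; shape (batch_size, 1) broadcasts across the classes
weight = torch.tensor([[1.0], [1.0], [0.5], [2.0]])

criterion = nn.BCEWithLogitsLoss(weight=weight, pos_weight=pos_weight)
loss = criterion(logits, targets)  # scalar, averaged over all elements ("mean" reduction)
print(loss.item())
```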
I think pos_weight is the one I was looking for. Thank you for your time!
Can you explain why they say "positive example" and not just "example"?
For binary classification, you would often hear "positive" and "negative" examples, which represent classes 1 and 0, respectively.
I think it’s the standard terminology, which is also used in e.g. confusion matrices and to calculate other metrics such as “True positive rate”, “True negative rate”, “False positive rate”, etc.
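To make the "positive" part concrete, a small sketch with hypothetical logits and targets: computing the unreduced loss with and without pos_weight shows that only the entries whose target is 1 (positive examples) get rescaled, each by the pos_weight of its class, while the entries with target 0 (negative examples) are left unchanged.

```python
import torch
import torch.nn as nn

# Two samples, two classes; 1 marks a positive example, 0 a negative one
logits = torch.tensor([[0.3, -1.2],
                       [2.0, 0.5]])
targets = torch.tensor([[1.0, 0.0],
                        [0.0, 1.0]])

pos_weight = torch.tensor([4.0, 0.5])  # one factor per class (column)

plain = nn.BCEWithLogitsLoss(reduction="none")
weighted = nn.BCEWithLogitsLoss(reduction="none", pos_weight=pos_weight)

# Element-wise ratio: positives are scaled by their class's pos_weight,
# negatives keep a factor of 1, so this is approximately [[4.0, 1.0], [1.0, 0.5]]
ratio = weighted(logits, targets) / plain(logits, targets)
print(ratio)
```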
Thanks, but that was not what I was looking for. To be more clear, can you give me an example of how to calculate the weights for the multilabel case? For example, suppose I have 10 examples and each example can belong to multiple labels/classes. In that situation, what should the process be to calculate pos_weight values that can be used in the loss function?
You could treat each occurrence of a class as a positive sample and calculate the pos_weight for each class.
E.g. if your complete dataset contains 100 samples in total, with 90 samples containing class0 and 80 samples containing class1, your pos_weight could be calculated as negative/positive = [10/90, 20/80].
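A minimal sketch of that calculation, assuming the targets are stored as a multi-hot tensor of shape (num_samples, num_classes). The target matrix below is hypothetical and built only to reproduce the 90/80 counts from the post; with real data you would sum the targets over the complete dataset. The clamp is an extra design choice (not from the post) to avoid dividing by zero for a class that never appears.

```python
import torch
import torch.nn as nn

# Hypothetical multi-hot targets: 100 samples, 2 classes.
# A row may contain several 1s, so per-class positive counts
# do not have to add up to the number of samples.
num_samples = 100
targets = torch.zeros(num_samples, 2)
targets[:90, 0] = 1.0   # 90 samples contain class0 (rows 0..89)
targets[20:, 1] = 1.0   # 80 samples contain class1 (rows 20..99)

positives = targets.sum(dim=0)              # positive count per class
negatives = num_samples - positives         # negative count per class
pos_weight = negatives / positives.clamp(min=1)  # negative/positive per class

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
print(positives, negatives, pos_weight)
```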
Sorry, the last bit is confusing. Can you elaborate?
100 samples = 90 Class0 samples + 10 Class1 samples, or 100 samples = 20 Class0 samples + 80 Class1 samples.
But how can 100 samples = 90 Class0 + 80 Class1?
Okay, I see: if the samples have multiple targets, it makes sense. Sorry for the stupid question!
If you only have one class, you can pass a tensor of length one as pos_weight. The 1s of that class will be upweighted accordingly, while the 0s keep a weight of 1.
FYI for future people finding this post, I struggled with it for a bit, lol.
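A minimal sketch of the single-class case, with hypothetical logits and targets: pos_weight is a tensor of shape (1,), here computed as negatives/positives (6/2 = 3), and it broadcasts over a batch of shape (N,). Scaling by negatives/positives is just one common choice, not something the loss requires.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Single-class (binary) problem: one logit and one 0/1 target per sample
logits = torch.randn(8)
targets = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])

# 6 negatives / 2 positives = 3.0, stored as a length-one tensor
num_pos = targets.sum()
num_neg = targets.numel() - num_pos
pos_weight = (num_neg / num_pos).reshape(1)  # shape (1,), broadcasts over the batch

# Target == 1 terms are multiplied by 3.0; target == 0 terms keep a weight of 1
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(logits, targets)
print(pos_weight, loss.item())
```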