I have a very imbalanced dataset (86% negative, 14% positive) and have been trying to use BCEWithLogitsLoss, specifically its weight parameter. I've looked at other topics covering this issue, but I haven't found an answer that settles it for me.
Right now I have the weights shaped according to the batch size, as noted in the documentation. Beyond that, though, I don't understand how to initialize them correctly so that they actually address the imbalance.
# one weight per sample, all set to the same value
weight = torch.ones(batch_size, 1) * 6.65
weight = weight.to(device, dtype=torch.float)
criterion = nn.BCEWithLogitsLoss(weight=weight)
I got 6.65 from n_total / n_positive. I know that simply repeating this value for every sample in the batch is wrong, but the documentation doesn't give a clear example of the right approach. Each batch has a different mix of classes, and short of writing a function that rebuilds the weight tensor to mirror each batch, I don't see how this was intended to be used, given that there aren't any solid examples to follow. I just want to use the weight argument correctly to see if I get a better representation of the minority class.
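To make concrete what I mean by rebuilding the weights per batch, here is a sketch of the per-sample weighting I think I would have to do. This uses the functional form so the weights can change every batch, and assumes my 6.65 factor is the right multiplier for positives (the 1.0 for negatives is my guess at the default):

```python
import torch
import torch.nn.functional as F

POS_FACTOR = 6.65  # my n_total / n_positive estimate

def weighted_bce_loss(logits, targets):
    # Build a weight tensor that mirrors this batch: each positive
    # sample gets POS_FACTOR, each negative sample gets 1.0.
    weights = torch.where(
        targets == 1.0,
        torch.full_like(targets, POS_FACTOR),
        torch.ones_like(targets),
    )
    # The functional form accepts a per-call weight, unlike the
    # module, whose weight is fixed at construction.
    return F.binary_cross_entropy_with_logits(logits, targets, weight=weights)

# Dummy batch with logits and targets of shape (batch_size, 1).
logits = torch.randn(8, 1)
targets = torch.randint(0, 2, (8, 1)).float()
loss = weighted_bce_loss(logits, targets)
```

That said, I've also seen pos_weight mentioned, which seems designed for exactly this (a single per-class factor broadcast over the whole batch), so maybe that's the intended route rather than hand-built per-sample weights.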
Any help is appreciated.