yes, that should solve it
It's not working. None of my class 2 samples are classified correctly.
I made a loss function for this:
import torch

def BCELoss_ClassWeights(input, target, class_weights):
    # input (n, d): predicted probabilities
    # target (n, d): binary labels
    # class_weights (1, d): per-class weights
    input = torch.clamp(input, min=1e-7, max=1 - 1e-7)
    bce = - target * torch.log(input) - (1 - target) * torch.log(1 - input)
    weighted_bce = (bce * class_weights).sum(axis=1) / class_weights.sum(axis=1)[0]
    final_reduced_over_batch = weighted_bce.mean(axis=0)
    return final_reduced_over_batch
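A quick self-contained sanity check (the function is restated from above; the toy tensors are made up): with equal class weights this loss should reduce to plain mean-reduced BCE, and upweighting class 2 shifts the total toward that class's error.

```python
import torch
import torch.nn.functional as F

def BCELoss_ClassWeights(input, target, class_weights):
    # restated from above: per-class weighted binary cross-entropy
    input = torch.clamp(input, min=1e-7, max=1 - 1e-7)
    bce = - target * torch.log(input) - (1 - target) * torch.log(1 - input)
    weighted_bce = (bce * class_weights).sum(axis=1) / class_weights.sum(axis=1)[0]
    return weighted_bce.mean(axis=0)

pred = torch.tensor([[0.9, 0.4], [0.2, 0.8]])    # made-up probabilities (n=2, d=2)
target = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

# with equal weights the result should match plain mean-reduced BCE
equal = BCELoss_ClassWeights(pred, target, torch.tensor([[1.0, 1.0]]))
plain = F.binary_cross_entropy(pred, target)

# upweighting class 2 (whose per-element BCE is larger here) raises the loss
upweighted = BCELoss_ClassWeights(pred, target, torch.tensor([[1.0, 10.0]]))
```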
torch.clamp(input, 1e-9, 1-1e-9) won't work with a torch.float32 input. It will simply clamp to 0 or 1 (1 - 1e-9, for instance, rounds to exactly 1.0 in single precision, since float32 spacing near 1 is about 1.2e-7). Using 1e-7 might be more appropriate:
>>> torch.clamp(torch.tensor(2.), max=1-1e-9) == 1
tensor(True)
>>> torch.clamp(torch.tensor(2.), max=1-1e-7) == 1
tensor(False)
Hi @miguelvr,
First of all, thank you for opening this topic.
I have a question for you: what if I initialize the criterion once for every batch, with a proper weight tensor that is batch-dependent?
Wouldn't that work properly?
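Concretely, what I have in mind is something like this (the tensors here are made up just to illustrate, and I'm assuming `nn.BCELoss` accepts a `weight` tensor with the same shape as the input):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# made-up batch, just to illustrate the idea
pred = torch.rand(4, 3) * 0.8 + 0.1            # probabilities in (0.1, 0.9)
target = (torch.rand(4, 3) > 0.5).float()

# a batch-dependent weight tensor, same shape as the input
batch_weights = torch.rand(4, 3) + 0.5

# re-creating the criterion for this batch with its own weight tensor
criterion = nn.BCELoss(weight=batch_weights)
loss = criterion(pred, target)

# same thing computed manually: mean of the element-wise weighted BCE
manual = (F.binary_cross_entropy(pred, target, reduction="none") * batch_weights).mean()
```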
Thank you in advance.