Hi Egor!
This suggests to me that you are performing a single-label, multi-class
classification.
That is, each image is either in class 0, or in class 1, or in class 2, etc.,
but not in more than one class (and not in no class).
For single-label, multi-class problems, you would typically use
`CrossEntropyLoss` (not `BCEWithLogitsLoss`).
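Here is a minimal sketch of what `CrossEntropyLoss` expects: raw logits of shape `(N, C)` and one integer class index per sample. The batch size, number of classes, and random data are made up for illustration. `BCEWithLogitsLoss`, by contrast, would treat each of the `C` outputs as an independent yes/no label, which is not what a single-label problem means.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical batch: 4 samples, 3 classes (illustration only).
logits = torch.randn(4, 3)             # raw, unnormalized scores, shape (N, C)
targets = torch.tensor([0, 2, 1, 2])   # one integer class index per sample, shape (N,)

# CrossEntropyLoss applies log_softmax over the C classes internally,
# so it takes raw logits and integer class labels (not one-hot floats).
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, targets)

# Same thing by hand: mean of -log(softmax probability of the true class).
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(targets)), targets].mean()
print(torch.allclose(loss, manual))  # True
```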
For `CrossEntropyLoss`, you would use its `weight` constructor argument.
(It plays a role similar to that of `BCEWithLogitsLoss`’s `pos_weight`
argument.)
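To make concrete what `weight` does, the following sketch (with hypothetical weights and random data) checks it against a hand computation: with the default `reduction="mean"`, each sample’s loss is scaled by the weight of its true class, and the sum is divided by the sum of those weights rather than by `N`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

logits = torch.randn(6, 3)
targets = torch.tensor([0, 1, 2, 2, 1, 0])
weight = torch.tensor([2.0, 1.0, 0.5])  # hypothetical per-class weights

loss = nn.CrossEntropyLoss(weight=weight)(logits, targets)

# Hand computation: weighted per-sample losses, normalized by the
# total weight of the targets actually present in the batch.
per_sample = nn.CrossEntropyLoss(reduction="none")(logits, targets)
w = weight[targets]
manual = (w * per_sample).sum() / w.sum()
print(torch.allclose(loss, manual))  # True
```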
One would typically use the reciprocal of a class’s frequency for its
weight, thus:
weight = torch.tensor ([6595 / 614, 6595 / 947, ...])
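If you have the training labels as a tensor, a sketch like the following computes the same “total count / class count” weights without typing them in by hand. The label tensor is hypothetical, and a class with zero samples would get an infinite weight, so you would want to guard against that in real data.

```python
import torch

# Hypothetical training labels for a 3-class problem (illustration only).
labels = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])
num_classes = 3

counts = torch.bincount(labels, minlength=num_classes)  # samples per class
weight = counts.sum() / counts.float()                  # total / count = 1 / frequency
print(counts)  # tensor([2, 3, 5])
print(weight)  # tensor([5.0000, 3.3333, 2.0000])
```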
Note that the exact values used for `weight` don’t matter much; they’re just
rough fudge factors to account for the class imbalance.
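One reason the exact values are forgiving: with the default `reduction="mean"`, only the relative sizes of the weights matter, because the loss is normalized by the sum of the weights. This sketch (hypothetical weights, random data) shows that multiplying every weight by the same factor leaves the loss unchanged; with `reduction="sum"` the overall scale would matter.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

logits = torch.randn(8, 3)
targets = torch.randint(0, 3, (8,))
weight = torch.tensor([6.0, 3.0, 1.5])  # hypothetical per-class weights

loss_a = nn.CrossEntropyLoss(weight=weight)(logits, targets)
loss_b = nn.CrossEntropyLoss(weight=10 * weight)(logits, targets)
print(torch.allclose(loss_a, loss_b))  # True
```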
Also, in your case, your biggest class imbalance is about 4.5-to-1, which
isn’t too bad, so you probably don’t need to use `weight`, although doing
so won’t hurt, and could help at the margins.
Best.
K. Frank