Hello Hyo and RaLo!
Yes, from Hyo’s post, this should be understood as an imbalanced
dataset. This can be addressed with BCEWithLogitsLoss’s pos_weight
constructor argument.
This is not necessarily imbalanced in the sense of, say, class 7 vs.
class 23 (might be, might not be – from what Hyo has said, we don’t
know yet), but it is imbalanced in the sense of the presence, say, of
class 7 vs. the absence of class 7.
Let me give a few words of explanation:
This multi-label, 100-class classification problem should be
understood as 100 binary classification problems (run through the
same network “in parallel”). For each of the classes, say class 7, and
each sample, you make the binary prediction as to whether that class
is present in that sample.
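In concrete terms, this means that both your network output (the
logits) and your targets have shape [nBatch, nClass], with the targets
being multi-hot, floating-point 0.0 / 1.0 values. A minimal sketch
(with made-up sizes and random data, just to show the shapes):

```
import torch

nBatch, nClass = 8, 100                              # made-up batch size; your 100 classes

logits = torch.randn (nBatch, nClass)                # raw network output (no sigmoid)
target = torch.randint (2, (nBatch, nClass)).float() # multi-hot: target[i, c] = 1.0 if class c is present in sample i

# one binary cross entropy per (sample, class) pair, averaged together
loss = torch.nn.BCEWithLogitsLoss() (logits, target)
```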
Your class-present / class-absent binary-choice imbalance is (averaged
over classes) something like 5% class-present vs. 95% class-absent.
This is imbalanced enough that your network is likely being trained
to predict any one specific class being present with low probability.
It sounds like this is what you are seeing.
(The “standard” approach for using pos_weight would be to calculate,
for each class c, the fraction of times, f_c, that class c is present
in your samples (regardless of which other classes are present or
absent), and then calculate the weight w_c = (1 - f_c) / f_c. You
then pass the one-dimensional tensor [w_0, w_1, ..., w_99] into
BCEWithLogitsLoss’s constructor as its pos_weight argument.)
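Here is a short sketch of that computation. (The sizes and the random
targets are made up; in practice you would compute f_c from your
actual training-set labels.)

```
import torch

nSample, nClass = 1000, 100                          # made-up sizes
# train_targets: multi-hot targets for your whole training set, shape [nSample, nClass]
train_targets = (torch.rand (nSample, nClass) < 0.05).float()

f_c = train_targets.mean (dim = 0)                   # fraction of samples in which each class is present
f_c = f_c.clamp (min = 1.0 / nSample)                # guard against division by zero for never-present classes
pos_weight = (1.0 - f_c) / f_c                       # w_c = (1 - f_c) / f_c

loss_fn = torch.nn.BCEWithLogitsLoss (pos_weight = pos_weight)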
A second comment:
The most straightforward way to convert your network output to
0 vs. 1 predictions is to threshold the output logits against
0.0. You are certainly allowed to convert the logits to probabilities,
and then threshold against 0.5 (or, equivalently, round), but doing
so is not necessary. More detail is given in this post:
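For example (random logits, just to illustrate that the two
thresholds give the same predictions):

```
import torch

logits = torch.randn (8, 100)                        # made-up network output

pred_a = (logits > 0.0).long()                       # threshold the logits against 0.0
pred_b = (torch.sigmoid (logits) > 0.5).long()       # convert to probabilities, threshold against 0.5

print (torch.equal (pred_a, pred_b))                 # True (the two agree, up to floating-point edge cases right at 0.0)
```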
Good luck.
K. Frank