Hi Mona!
Let me outline how to use BCEWithLogitsLoss and CrossEntropyLoss with class weights without commenting directly on your code.
When performing binary classification, I prefer using BCEWithLogitsLoss. Doing so more nearly “says what you mean” than CrossEntropyLoss does, and is marginally more efficient.
It is perfectly reasonable, however, to treat binary classification as the two-class case of multi-class classification, and use CrossEntropyLoss.
Let’s assume that the input to your model is a batch of nBatch samples.
In typical usage with BCEWithLogitsLoss the final layer of your network will be a Linear with out_features = 1, and the output of your model will be a batch of nBatch logit predictions with shape [nBatch, 1].
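Something like the following sketch (the feature count, hidden size, and batch size here are just made-up placeholders):

```python
import torch

nFeatures = 10   # hypothetical number of input features
nBatch = 4       # hypothetical batch size

# final Linear has out_features = 1, so the model emits one logit per sample
model = torch.nn.Sequential (
    torch.nn.Linear (nFeatures, 32),
    torch.nn.ReLU(),
    torch.nn.Linear (32, 1)
)

input = torch.randn (nBatch, nFeatures)
logits = model (input)   # shape [nBatch, 1]
```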
If you were to convert your logit to a probability (you don’t – this is done internally, in effect, in BCEWithLogitsLoss), you would get the predicted probability of your sample being in “class-1” (the “positive” class). Your target would be the known probability of your sample being in “class-1” and can be exactly 0.0 or 1.0, in which case it is easy to think of this probability as being a 0 / 1 label where 0.0 means “class-0” and 1.0 means “class-1.”
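So the loss computation would look something like this (continuing the sketch above, with made-up 0.0 / 1.0 target labels):

```python
loss_fn = torch.nn.BCEWithLogitsLoss()

# float "probabilities" of being in "class-1" -- here exact 0 / 1 labels
target = torch.tensor ([[0.0], [1.0], [1.0], [0.0]])   # shape [nBatch, 1]

loss = loss_fn (logits, target)   # logits are converted to probabilities internally
```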
If your data is unbalanced, such as in your case where 16% of your training samples are in “class-1,” you can use BCEWithLogitsLoss’s pos_weight constructor argument to weight the “class-1” samples more heavily in the calculated loss. You would typically use a weight of class-0-% / class-1-%; thus in your case you might use pos_weight = torch.tensor ([5.25]).
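That is (just the loss construction changes; the rest of the sketch stays the same):

```python
# class-0-% / class-1-% = 0.84 / 0.16 = 5.25
loss_fn = torch.nn.BCEWithLogitsLoss (pos_weight = torch.tensor ([5.25]))
loss = loss_fn (logits, target)
```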
If, instead, you choose to treat this as the two-class case of a multi-class problem and use CrossEntropyLoss, your final layer should be a Linear with out_features = 2, that is, separate output values for “class-0” and “class-1”, and the output of your model would be a batch of class predictions with shape [nBatch, 2]. (These are again logits that are, in effect, internally converted to probabilities in CrossEntropyLoss.) Your target will be (a batch of) integer class labels that take on the values 0 (for “class-0”) and 1 (for “class-1”), and will have shape [nBatch].
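A corresponding sketch (again with placeholder sizes and made-up labels):

```python
# final Linear now has out_features = 2 -- one logit per class
model = torch.nn.Sequential (
    torch.nn.Linear (nFeatures, 32),
    torch.nn.ReLU(),
    torch.nn.Linear (32, 2)
)

logits = model (input)                 # shape [nBatch, 2]

loss_fn = torch.nn.CrossEntropyLoss()
target = torch.tensor ([0, 1, 1, 0])   # integer class labels, shape [nBatch]
loss = loss_fn (logits, target)
```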
To weight your (two) classes in the loss calculation, you would use CrossEntropyLoss’s weight constructor argument. Now, instead of a single weight (such as BCEWithLogitsLoss’s pos_weight) for “class-1,” you will have a weight for each of your (two) classes. You would typically weight each class proportionally to the reciprocal of the frequency with which it appears in your training data. So in your case, you could use weight = torch.tensor ([1.0, 5.25]).
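That is:

```python
# weight each class (roughly) by the reciprocal of its frequency
loss_fn = torch.nn.CrossEntropyLoss (weight = torch.tensor ([1.0, 5.25]))
loss = loss_fn (logits, target)
```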
Note that another approach to compensating for unbalanced training data is to sample the underrepresented class more heavily. In your case you could build your training batches by sampling randomly from your training data, but sample any specific “class-1” sample 5.25 times as often as any specific “class-0” sample. Now a given batch will contain, on average, an equal number of “class-1” samples and “class-0” samples. (You can use this technique with both the BCEWithLogitsLoss and the CrossEntropyLoss approach, and you would no longer use class weights in the loss calculation.)
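One way to do this is with pytorch’s WeightedRandomSampler. A sketch, assuming (hypothetically) that your training samples live in tensors features and labels, where labels holds the 0 / 1 class labels:

```python
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# hypothetical training data
features = torch.randn (1000, nFeatures)
labels = torch.randint (0, 2, (1000,))

# give each "class-1" sample 5.25 times the sampling weight of a "class-0" sample
sample_weights = torch.ones (len (labels))
sample_weights[labels == 1] = 5.25

sampler = WeightedRandomSampler (sample_weights, num_samples = len (labels), replacement = True)
loader = DataLoader (TensorDataset (features, labels), batch_size = 32, sampler = sampler)
```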
Best.
K. Frank