Cross Entropy Loss for imbalanced set (binary classification)

Dear community,

I am trying to use class weights for a binary classification problem with CrossEntropyLoss, and by now I am quite lost in it….

In my network I set the output size to 1 and have a sigmoid activation function at the end to ensure I get values between 0 and 1, which I assume are probabilities. If the output size is set to 2 (one column per class), then for some reason the two columns do not sum to 1. This is why I set the output size to 1.
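For example, a quick check of what I mean (the layer sizes here are just for illustration): the two raw outputs do not sum to 1 unless a softmax is applied.

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(3, 2)          # 2-output layer emits raw scores
out = layer(torch.randn(4, 3))

row_sums = out.sum(dim=1)              # generally NOT 1.0 per row
probs = torch.softmax(out, dim=1)      # normalize across the two classes
prob_sums = probs.sum(dim=1)           # each row now sums to 1.0
```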

The loss function requires two columns in the output, I assume one for class 0 and one for class 1, but I am not sure in which order. I set the weights as tensor([1., 5.]), assuming that I have five times more class-0 samples than class-1 samples.

Thus in the training loop I have:

outputs = nn_model(X_batch)

I think this is the probability of class 0. The output has one column.

outputs = outputs.view(X_batch.shape[1]*X_batch.shape[0],output_size)

I reshape it into one column.

one = torch.ones(outputs.shape)

I create a tensor of ones.

class1 = one-outputs

I calculate the probability of class 1 by subtracting the probability of class 0 from 1.

outputs = torch.cat((outputs, class1), 1)

I concatenate so that the 1st column is the probability of class 0 and the 2nd column is the probability of class 1.

loss = criterion(outputs, labels)

I assume the 1st column is class 0 and the 2nd is class 1, so the weights are tensor([1., 5.]), reflecting that I have five times more zeroes.

Is this right? If it is wrong, then please help me figure out what and why, PLEASE… Any comments are very much appreciated.

Hi Alice!

For a binary classification problem (two classes, “yes” and “no”)
you will prefer to use BCEWithLogitsLoss rather than
CrossEntropyLoss.

You will want your model to have a single output (so that the shape
of the output is [nBatch, 1]) and have no final Sigmoid activation
layer.

For your use case you will probably want to use the pos_weight
argument to BCEWithLogitsLoss's constructor.

In this case one would typically use:

loss_function = torch.nn.BCEWithLogitsLoss (pos_weight = torch.tensor (5.0))

so that you weight your positive sample five times more heavily than
your five-times-more-frequent negative samples.
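Put together, a minimal sketch of this setup (the model, feature count, and batch below are made-up placeholders, not code from this thread):

```python
import torch

torch.manual_seed(0)
n_features = 4                                   # illustrative feature count
model = torch.nn.Linear(n_features, 1)           # single output, no Sigmoid

# weight positive samples 5x, since negatives are five times more frequent
loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(5.0))

X_batch = torch.randn(8, n_features)
labels = torch.randint(0, 2, (8, 1)).float()     # float targets, shape [nBatch, 1]

logits = model(X_batch)                          # shape [8, 1], raw logits
loss = loss_function(logits, labels)
loss.backward()
```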

If you have questions about how or why to use BCEWithLogitsLoss,
please follow up.

Best.

K. Frank


Dear K. Frank,
Thank you for your time looking into it. I switched from BCELoss to CrossEntropyLoss because BCELoss lacks an option to add class weights.
I had not yet come across BCEWithLogitsLoss.
Am I right that

  • here the output and target will be of the same shape?
  • I should remove the sigmoid activation at the end, as it is already part of BCEWithLogitsLoss?

Thank you again and best regards,
Alice

Hi Alice!

For reasons I don’t understand – I suppose that it’s just an inconsistency
or oversight – BCELoss lacks BCEWithLogitsLoss's pos_weight
argument. It’s not really an issue – BCEWithLogitsLoss should be
used anyways because of its better numerical stability.
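The stability difference is easy to see with a saturated logit (the value 20.0 below is just an illustration): sigmoid(20.) rounds to exactly 1.0 in float32, so the separate log(1 - p) step blows up, while the fused version stays accurate.

```python
import torch

x = torch.tensor([20.0])   # a confidently-positive logit
y = torch.tensor([0.0])    # but the target is the negative class

# Unstable path: sigmoid saturates to 1.0, log(1 - p) diverges, and
# BCELoss falls back to its internal clamp (a loss of 100).
unstable = torch.nn.BCELoss()(torch.sigmoid(x), y)

# Stable path: BCEWithLogitsLoss fuses sigmoid and log, giving ~20.0.
stable = torch.nn.BCEWithLogitsLoss()(x, y)
```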

Yes, and yes.

Best.

K. Frank


Dear K. Frank,

Thank you! It works and I now understand what I am doing there, whereas what I did before was a mess.

I am very happy.

Kind regards,

Alice

Dear K. Frank,

I have one more follow-up question. For my output I wish to get probabilities; that is why I had a sigmoid function at the end of forward(). Now, using BCEWithLogitsLoss, I deleted the sigmoid function, and my output can be negative and can be bigger than 1 ;(

Do you know how to address this trouble? Or did I completely miss something in our previous discussion?

Best, Alice

Hi Alice!

Yes, without the Sigmoid activation function, the output of your
model will be raw scores, so-called logits. They run from -inf to
inf, and are what you want as the input to BCEWithLogitsLoss.
Sigmoid maps logits to probabilities that run from 0.0 to 1.0.

You may well want probabilities for certain purposes. In such a case,
you still want your model to output logits (no Sigmoid) that you feed
to BCEWithLogitsLoss (and then backpropagate). When you want
probabilities, just apply Sigmoid to your logits (separately from your
model, loss function, and backpropagation) to convert them to
probabilities for whatever subsequent processing you have.

A word of explanation: The logits and probabilities contain the same
information and can be transformed mathematically back and forth
into one another. For numerical reasons it’s better to pass logits
from your model to your BCEWithLogitsLoss loss function.
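For example (the values below are arbitrary), sigmoid and the logit function invert one another, so no information is lost by outputting logits:

```python
import torch

logits = torch.tensor([-2.0, 0.0, 3.0])
probs = torch.sigmoid(logits)      # maps (-inf, inf) into (0.0, 1.0)
recovered = torch.logit(probs)     # log(p / (1 - p)) inverts the sigmoid
```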

Best.

K. Frank


Dear K. Frank, Thank you for your prompt and detailed answer. It is very clear now.

Can I then transform logits to probabilities with a new network model (code below),

class Sig(nn.Module):
    def __init__(self):
        super(Sig, self).__init__()
        self.activation = nn.Sigmoid()

    def forward(self, x):
        return self.activation(x)

#create a network 
sigmoid = Sig()

and may I call this mini model within the training loop to calculate accuracy on the validation set?

Best regards,
Alice

Hi Alice!

Let me answer your question(s) two different ways.

You could, but doing so would be overkill. You can just call the
function (or class) version of sigmoid() directly:

my_logits = my_model (my_batch)
with torch.no_grad():
    my_probabilities = torch.sigmoid (my_logits)
    # or instantiate a Sigmoid function object on the fly and call it
    # my_probabilities = torch.nn.Sigmoid() (my_logits)

There is no need to wrap sigmoid() in a “mini-model” in order to
apply it to my_logits.

You generally don’t need the actual probabilities to calculate the
accuracy of your validation-set predictions. You just need to turn
the logits into binary yes/no predictions.

Also, you may want to calculate the validation-set loss, as well.

Something like:

# assumes loss_criterion = BCEWithLogitsLoss (...)
with torch.no_grad():   # don't want or need gradients for validation calculations
    val_logits = my_model (val_batch)
    val_loss = loss_criterion (val_logits, val_targets)
    # do something with val_loss
    #
    # assumes that val_targets are exactly 0.0 and 1.0
    val_binary_preds = (val_logits > 0.0).float()
    num_correct = (val_binary_preds == val_targets).sum()
    # use num_correct to calculate average validation accuracy, etc.

Note that a logit of zero corresponds to a probability of one half.
(sigmoid (0.0) == 0.5.) So thresholding logits against 0.0 gives
the same results as thresholding the corresponding probabilities
against 0.5, the idea being that probability > 0.5 means “yes”
(and probability <= 0.5 means “no”).
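You can check this equivalence directly (the logit values below are arbitrary examples):

```python
import torch

logits = torch.tensor([-1.5, 0.0, 0.2, 3.0])

# thresholding logits against 0.0 ...
preds_from_logits = (logits > 0.0).float()

# ... gives the same predictions as thresholding probabilities against 0.5
preds_from_probs = (torch.sigmoid(logits) > 0.5).float()
```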

Best.

K. Frank


Thank you so much! This is all so great.

Thank you, Frank.

Kind regards,

Alice