With BCEWithLogitsLoss, how do I deal with pos_weight if there are no positive cases in a batch?

My loss function is:

import torch

def loss_fn(inp, target):
    # per-class counts of negative and positive targets in this batch
    zeros_sum = (target == 0).sum(dim=0).float()
    one_sum = (target == 1).sum(dim=0).float()

    # ratio of negatives to positives, used to up-weight positive samples
    pos_weight = zeros_sum / one_sum
    criterion = torch.nn.BCEWithLogitsLoss(reduction='mean', pos_weight=pos_weight)

    loss = criterion(inp, target)

    return loss

However, in certain batches there are no positive examples for a particular class, so one_sum is zero for that class. How do I deal with pos_weight in those cases?

Hi Shamoon!

I would recommend (and would in general, not just for your case)
not computing pos_weight on a per-batch basis, but, instead,
using the pos_weight for your whole training set, or for a
representative sample of your training set.

In any event, your individual batches should be representative
of your whole training set. They differ from one another randomly,
of course, but shouldn’t differ from one another systematically.
For example, you wouldn’t want to train on several batches of
mostly positive samples, and then train on batches of mostly
negative samples.

Best.

K. Frank

I’m using the DataLoader, so I don’t really have all of the samples loaded at once - only batch by batch. And since those batches are randomly sampled, isn’t it the same to have my pos_weight be per batch as per dataset?

Hello Shamoon!

Pre-process your training set. Count the number of positive vs.
negative samples in your training set – or if this is too bulky or
expensive, in a representative sample of your training set – and
use these counts to calculate a single value of pos_weight that
you use for all of your batches.
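
A minimal sketch of that pre-processing step (names such as
train_dataset are just placeholders for your own objects):

import torch
from torch.utils.data import DataLoader

# Sketch: accumulate per-class counts of positive and negative targets
# over the whole training set (or a representative sample of it).
loader = DataLoader(train_dataset, batch_size=256, shuffle=False)

pos_counts = None  # per-class number of positive (== 1) targets
neg_counts = None  # per-class number of negative (== 0) targets

for _, target in loader:
    ones = (target == 1).sum(dim=0).float()
    zeros = (target == 0).sum(dim=0).float()
    pos_counts = ones if pos_counts is None else pos_counts + ones
    neg_counts = zeros if neg_counts is None else neg_counts + zeros

# single pos_weight, reused for every training batch
pos_weight = neg_counts / pos_counts
criterion = torch.nn.BCEWithLogitsLoss(reduction='mean', pos_weight=pos_weight)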

Well, for large enough batches, this will be approximately true.
But, as you’ve discovered, with your batch size this can fail
because a batch can contain no positive sample, so you
get a zero-divide when you try to calculate a per-batch
pos_weight for such a batch.
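
To make that failure mode concrete, here is a small, purely
illustrative sketch (the numbers are made up) showing how a batch
with no positives for one class produces an inf pos_weight and a
non-finite loss:

import torch

# Tiny two-class batch in which class 1 has no positive targets.
target = torch.tensor([[1., 0.],
                       [0., 0.],
                       [1., 0.]])
inp = torch.randn(3, 2)  # arbitrary logits

one_sum = (target == 1).sum(dim=0).float()   # tensor([2., 0.])
zero_sum = (target == 0).sum(dim=0).float()  # tensor([1., 3.])
pos_weight = zero_sum / one_sum              # tensor([0.5000, inf])  <- zero-divide

criterion = torch.nn.BCEWithLogitsLoss(reduction='mean', pos_weight=pos_weight)
print(criterion(inp, target))                # the reduced loss comes out nan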

Best.

K. Frank

Part of the reason that some of the classes will not have any positive samples is that I am using pad_sequence on my data, so a bunch of 0s are added to the end. So I’ll always have some elements that have no positive case.

And it’s not that the batch doesn’t contain ANY positive samples. It doesn’t contain positive samples for a specific position of the target vector.

Hi Shamoon!

From this I understand that you are performing a multi-label,
multi-class classification problem, where your “classes” are
the positions in your sequence.

If it is really true that some positions (that is, some classes)
are never positive across your entire training set, then don’t
reweight such classes (no zero-divides). Simply let your network
learn that these classes are never positive, and your network
will always (correctly) predict negative for them.
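
Building on the dataset-wide counts from the earlier sketch
(pos_counts and neg_counts are my placeholder names), “don’t
reweight such classes” just means giving them the neutral weight
of 1.0 instead of the inf that the raw division would produce:

import torch

# Neutral weight (1.0) for classes that never occur as positive;
# the usual negatives / positives ratio for everything else.
pos_weight = torch.where(pos_counts > 0,
                         neg_counts / pos_counts,
                         torch.ones_like(pos_counts))

criterion = torch.nn.BCEWithLogitsLoss(reduction='mean', pos_weight=pos_weight)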

Good luck!

K. Frank

So then are you proposing that I don’t use pos_weight at all? But I do have a severe imbalance, since most of my values are negative most of the time.