I would recommend (in general, not just for your case)
not computing pos_weight on a per-batch basis, but instead
using a pos_weight computed from your whole training set, or from a
representative sample of it.
In any event, your individual batches should be representative
of your whole training set. They differ from one another randomly,
of course, but shouldn’t differ from one another systematically.
For example, you wouldn’t want to train on several batches of
mostly positive samples, and then train on batches of mostly
negative samples.
I’m using the DataLoader, so I don’t really have all of the samples loaded at once – only batch by batch. And since those batches are loaded in random order, isn’t it the same whether my pos_weight is computed per batch or over the whole dataset?
Pre-process your training set. Count the number of positive vs.
negative samples in your training set – or if this is too bulky or
expensive, in a representative sample of your training set – and
use these counts to calculate a single value of pos_weight that
you use for all of your batches.
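A minimal sketch of this pre-processing step, assuming binary targets and a hypothetical dataset of random tensors standing in for your real one: make one pass over the DataLoader to count positives and negatives, then build a single pos_weight for BCEWithLogitsLoss.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in for your training set: binary targets
# of shape (n_samples, seq_len), roughly 20% positive.
targets = (torch.rand(1000, 8) > 0.8).float()
loader = DataLoader(TensorDataset(targets), batch_size=64)

# One pass over the whole training set to count positives and negatives.
n_pos = 0.0
n_total = 0.0
for (batch,) in loader:
    n_pos += batch.sum().item()
    n_total += batch.numel()
n_neg = n_total - n_pos

# A single scalar pos_weight, reused unchanged for every batch.
pos_weight = torch.tensor(n_neg / n_pos)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
```

Because pos_weight is computed once over the full dataset, it can never hit the zero-divide that a per-batch computation risks (as long as the dataset contains at least one positive sample).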
Well, for large enough batches, this will be approximately true.
But, as you’ve discovered, with your batch size this can fail
because a batch can contain no positive samples, so you
get a zero-divide when you try to calculate a per-batch
pos_weight for such a batch.
Part of the reason that some of the classes will not have any positive samples is that I am using pad_sequence on my data, so a bunch of 0s are added to the end. So I’ll always have some elements that have no positive case.
And it’s not that the batch doesn’t contain ANY positive samples – it doesn’t contain positive samples for a specific position of the target vector.
From this I understand that you are performing a multi-label,
multi-class classification problem, where your “classes” are
the positions in your sequence.
If it is really true that some positions (that is, some classes)
are never positive across your entire training set, then don’t
reweight such classes (no zero-divides). Simply let your network
learn that these classes are never positive, and your network
will always (correctly) predict negative for them.
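One way to sketch this, with hypothetical per-class counts (the class names and numbers are made up for illustration): compute a per-class pos_weight vector, but leave any class with zero positives at the default weight of 1, so there is no zero-divide and the network is free to learn “always negative” for those positions.

```python
import torch

# Hypothetical counts over the whole training set; each "class" is a
# position in the (padded) target vector. The last position has no
# positive samples anywhere in the dataset (e.g. it is always padding).
n_pos = torch.tensor([120.0, 45.0, 300.0, 0.0])
n_neg = torch.tensor([880.0, 955.0, 700.0, 1000.0])

# pos_weight = n_neg / n_pos per class, except that classes with no
# positives keep the default weight of 1 (no reweighting, no zero-divide).
# clamp(min=1) keeps the unused branch of torch.where free of inf.
pos_weight = torch.where(
    n_pos > 0,
    n_neg / n_pos.clamp(min=1.0),
    torch.ones_like(n_pos),
)

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
```

BCEWithLogitsLoss broadcasts a pos_weight vector of length `num_classes` across the batch dimension, so this one vector serves every batch.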