BCE class weightings causing loss instability

I’m using BCEWithLogitsLoss in a multilabel classification problem in the following way:

return F.binary_cross_entropy_with_logits(
    logits, targets,
    weight=self.freq_weight, pos_weight=self.pn_weight, reduction='mean')

where self.freq_weight is a tensor of per-class weights meant to balance the relative occurrence of each class, and self.pn_weight is the per-class weight on positive examples, as described in the docs here

My intuition is that the extreme magnitude differences among the class weights produced by the imbalance in my dataset are causing huge loss spikes during training and preventing the loss from converging. What are some ways to deal with this?

Something you should definitely consider is using a WeightedRandomSampler instead: it balances out your classes at the sampling level, without running into the exploding-loss issue you described.
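A minimal sketch of what that looks like (toy data; weighting each sample by the inverse frequency of its rarest positive label is just one assumed choice):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy multilabel targets: n_samples x n_classes.
targets = torch.tensor([
    [1., 0.],
    [1., 0.],
    [1., 0.],
    [0., 1.],
])

class_freq = targets.sum(dim=0)                          # positives per class
inv_freq = 1.0 / class_freq                              # rarer class -> larger weight
sample_weights = (targets * inv_freq).max(dim=1).values  # rarest present label dominates

# Sampler draws samples (with replacement) in proportion to sample_weights,
# so the rare-label sample is picked far more often than raw shuffling would.
sampler = WeightedRandomSampler(sample_weights.tolist(),
                                num_samples=len(sample_weights),
                                replacement=True)
loader = DataLoader(TensorDataset(targets), batch_size=2, sampler=sampler)
```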

One reservation I have with a weighted random sampler is that samples in my dataset often carry multiple labels, and the labels that appear infrequently often co-occur with labels that appear frequently. Oversampling instances of an infrequent label would therefore also oversample the frequent labels, which could end up being counterproductive.

Maybe there is some more complex weighted sampling scheme where I also decrease the weight of instances that don't contain infrequent labels, but that seems like a very convoluted path.

That’s a legitimate reservation; however, what you just described doesn’t seem that convoluted to me, since WeightedRandomSampler takes custom per-sample weights. One way to come up with them is to rank your classes by frequency, normalized to [0, 1] (0 is least frequent, 1 is most frequent), and for each datapoint i assign a score that looks like:

score_i = product of (1 - rank_k) over all classes k that appear in datapoint i

Then you use score_i to produce the weights you pass to WeightedRandomSampler.

This very crude method basically rewards datapoints that contain rare classes but not common ones. You can use it as a fine-tuning step for a few epochs at the end of training (i.e. you mostly train your model without it), with a validation set that focuses on the rare classes.
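The scoring scheme above can be sketched like this (toy data; all variable names are hypothetical):

```python
import torch

targets = torch.tensor([
    [1., 1., 0.],   # common + mid-frequency class
    [0., 0., 1.],   # rare class only
    [1., 0., 0.],   # common class only
    [1., 0., 1.],   # common + rare class
])

# Rank classes by frequency, normalized to [0, 1]:
# 0 = least frequent, 1 = most frequent.
freq = targets.sum(dim=0)                       # here: [3., 1., 2.]
order = freq.argsort()                          # classes from rarest to most common
rank = torch.empty_like(freq)
rank[order] = torch.arange(len(freq), dtype=freq.dtype) / (len(freq) - 1)

# score_i = product of (1 - rank_k) over classes k present in datapoint i.
# Absent classes contribute a factor of 1 so they don't affect the product.
factors = torch.where(targets.bool(), 1.0 - rank, torch.ones_like(targets))
scores = factors.prod(dim=1)

# Note: the most common class has rank 1, so any datapoint containing it
# gets score 0 and is never sampled; in practice you might cap ranks below 1.
sampler_weights = scores / scores.sum()  # pass these to WeightedRandomSampler
```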