I’m a bit confused by an empirical observation when using BCEWithLogitsLoss.
The scenario: I have very sparse multi-target labels for 1000 classes. Each sample has on average 1 label, but a sample can have 0, 1, 2, or more labels, so it is a multi-target problem. I’m training a CNN (for instance resnet18) on a balanced training dataset. The loss is evaluated on the full-size matrices of labels and of logits output by the CNN:
loss = torch.nn.BCEWithLogitsLoss()(logits_2d, target_2d)
Observation: If I train with BCEWithLogitsLoss (without specifying pos_weight), the model always quickly learns to predict 0 for all classes and all samples, and it then cannot escape this local minimum. This makes sense: the labels are 99.9% 0, so the loss decreases when the model outputs low logit values for all classes. During training, all logits quickly become negative. To avoid this problem, I can specify a large pos_weight.
This makes the 1/1000 positive labels as important as the 999/1000 negative labels. The model then trains properly, with logit outputs averaging around 0.
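To make the arithmetic behind this concrete, here is a minimal pure-Python sketch of the per-element loss documented for BCEWithLogitsLoss, -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))], for a single sample with 1 positive and 999 negative labels. The helper names and the pos_weight value of 999 are assumptions for illustration, not the values from my actual training run.

```python
import math

NUM_CLASSES = 1000  # assumed: 1 positive label, 999 negative labels

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def bce_with_logits(logit, target, pos_weight=1.0):
    """Per-element loss: -[w * y * log(s(x)) + (1 - y) * log(1 - s(x))]."""
    p = sigmoid(logit)
    return -(pos_weight * target * math.log(p)
             + (1.0 - target) * math.log(1.0 - p))

def grad_wrt_logit(logit, target, pos_weight=1.0):
    """Derivative of the per-element loss with respect to the logit."""
    p = sigmoid(logit)
    return -pos_weight * target * (1.0 - p) + (1.0 - target) * p

targets = [1.0] + [0.0] * (NUM_CLASSES - 1)

def mean_loss(logit, pos_weight=1.0):
    """Mean loss over all classes when every logit has the same value."""
    return sum(bce_with_logits(logit, t, pos_weight) for t in targets) / NUM_CLASSES

# Unweighted: pushing every logit down lowers the mean loss,
# even though the single positive label is then badly wrong.
loss_at_zero = mean_loss(0.0)
loss_at_minus_seven = mean_loss(-7.0)

# Summed gradient at logit 0, split by label value.
neg_pull = sum(grad_wrt_logit(0.0, t) for t in targets if t == 0.0)
pos_pull = sum(grad_wrt_logit(0.0, t) for t in targets if t == 1.0)
pos_pull_weighted = sum(
    grad_wrt_logit(0.0, t, pos_weight=NUM_CLASSES - 1) for t in targets if t == 1.0
)

print(loss_at_zero, loss_at_minus_seven)
print(neg_pull, pos_pull, pos_pull_weighted)
```

At logit 0 the 999 negatives contribute a combined gradient 999 times larger in magnitude than the single positive. With pos_weight equal to the negative-to-positive ratio, the two sides cancel, which matches the logits averaging around 0.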
This “trick” of heavily weighting the loss doesn’t seem to be necessary when training in TF or jax+optax: default parameters of binary cross entropy loss train fine with sparse multi-target labels.
Question: Is there some implementation difference between BCEWithLogitsLoss and the corresponding loss functions in other libraries (e.g. optax BCE loss or TF CategoricalCrossentropy loss) such that the default behavior in other packages acts like specifying a large pos_weight?
I don’t know what reduction=losses_utils.ReductionV2.AUTO means in the TF implementation, so I’m not sure if that could be involved.
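For context on why a reduction setting could matter at all, the sketch below compares two ways of collapsing a [batch, classes] matrix of per-element losses into a scalar. It is pure arithmetic on made-up numbers, with hypothetical function names, and it makes no claim about what TF or optax actually do by default.

```python
def reduce_mean_all(losses):
    """Mean over every element of a [batch][classes] matrix."""
    total = sum(sum(row) for row in losses)
    count = sum(len(row) for row in losses)
    return total / count

def reduce_sum_classes_mean_batch(losses):
    """Sum over classes within each sample, then mean over the batch."""
    return sum(sum(row) for row in losses) / len(losses)

# Made-up per-element losses: 2 samples, 4 classes.
losses = [
    [0.5, 0.25, 0.25, 1.0],
    [1.0, 0.5, 0.5, 0.0],
]

mean_all = reduce_mean_all(losses)
sum_then_mean = reduce_sum_classes_mean_batch(losses)
ratio = sum_then_mean / mean_all  # equals the number of classes

print(mean_all, sum_then_mean, ratio)
```

The two scalars differ by a factor equal to the number of classes, so with 1000 classes the gradients would differ by 1000x at the same learning rate. Neither choice reweights positives against negatives the way pos_weight does.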
Other users have reported similar behavior for PyTorch BCE Loss: