Weight for BCE loss

Hi there. Is there a way for me to compute the BCE loss over different areas of a batch with different weights? It seems that the weight argument ("a manual rescaling weight; if provided, it's repeated to match the input tensor shape") of torch.nn.functional.binary_cross_entropy_with_logits and of BCELoss (PyTorch 1.9.1 documentation) only works at the batch level, not pixel by pixel. Can I apply the weight to specific pixels? Thanks.

Hi James!

The documentation is misleadingly unclear, but BCEWithLogitsLoss
(which you should use) and BCELoss (which you shouldn’t use)
do support what you call “pixel-by-pixel” weights (as do their
functional equivalents).

Let’s assume that input and target have the same shape (which
they will in any reasonable use case). Then what is required is that
the weight argument passed into BCEWithLogitsLoss's constructor
have a shape that is broadcastable to the shape of input (and hence
to the shape of target). After broadcasting, weight simply scales each
term in the loss function on an element-by-element basis.

(Those elements can be pixels in an image, classes in a
multi-label problem, channels, etc., or any combination thereof.
BCEWithLogitsLoss doesn’t care how you choose to interpret
them – they’re just elements of a multi-dimensional tensor.)
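As a concrete sketch (the shapes here are made up for illustration): a weight of shape [height, width] broadcasts across the batch dimension, and after broadcasting it just multiplies each element's term in the unreduced loss.

```python
import torch

torch.manual_seed(0)

input = torch.randn(4, 2, 3)   # logits: a batch of 4 "images" of 2 x 3 pixels
target = torch.rand(4, 2, 3)   # targets in [0.0, 1.0]

# weight of shape [2, 3] is broadcastable to input's shape [4, 2, 3]
weight = torch.rand(2, 3)

loss = torch.nn.BCEWithLogitsLoss(weight=weight, reduction='sum')(input, target)

# equivalently: compute the unreduced loss and weight it element by element
per_elem = torch.nn.functional.binary_cross_entropy_with_logits(
    input, target, reduction='none')
manual = (weight * per_elem).sum()

print(torch.allclose(loss, manual))   # True
```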

If weight has the same shape as input, then no broadcasting is done
(or you could say that the broadcasting is trivial), and weight will be
applied to “specific pixels” in the way that I think you’re asking for.
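For example (again with made-up shapes), with weight the same shape as input, each pixel's term is scaled individually. One thing to note: the default reduction = 'mean' still divides by the total number of elements, not by weight.sum().

```python
import torch

torch.manual_seed(0)

input = torch.randn(4, 2, 3)     # logits
target = torch.rand(4, 2, 3)     # targets in [0.0, 1.0]
weight = torch.rand(4, 2, 3)     # one weight per "pixel," same shape as input

loss = torch.nn.functional.binary_cross_entropy_with_logits(
    input, target, weight=weight)

# the default reduction = 'mean' divides by input.numel(), not weight.sum()
per_elem = torch.nn.functional.binary_cross_entropy_with_logits(
    input, target, reduction='none')
manual = (weight * per_elem).mean()

print(torch.allclose(loss, manual))   # True
```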

Best.

K. Frank

Many thanks for your detailed explanation, Frank. Is there a way to explicitly mask the cross-entropy loss with PyTorch's BCE loss? I would like to mask out the low-possibility pixels with weight = 0 and keep the others.

Hi James!

If I understand your question correctly, you can just use weights of 0.0
and 1.0 to do this.

Let’s say that the input to your model is a batch of images of shape
[nBatch, height, width], and that the output of your model (which
will be the input to BCEWithLogitsLoss) and the target you pass
to BCEWithLogitsLoss also have this shape.

Now let me assume that you somehow construct a
pixel_possibility_tensor, also of the same shape, and that this
tensor tells you the “possibility” of each pixel (whatever that means
in your use case). If a “possibility” of, say, less than 0.35 is considered
“low possibility”, you would:

weight = (pixel_possibility_tensor > 0.35).float()
loss = (weight.numel() / weight.sum()) * torch.nn.BCEWithLogitsLoss(weight=weight)(input, target)

In the loss computation, the 0.0 values in weight will mask out
the “low possibility” pixels (and the 1.0 values will keep unchanged
the remaining pixels). The weight.numel() / weight.sum()
pre-factor is so that loss will be the loss averaged only over the
non-masked pixels, assuming that’s what you want – without it, all
of the masked pixels would also cause a bunch of zeros to get
averaged into the loss.
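To check this numerically (pixel_possibility_tensor here is just random made-up data), the prefactored loss matches the unreduced per-pixel loss averaged over only the kept pixels:

```python
import torch

torch.manual_seed(0)

input = torch.randn(4, 2, 3)                     # logits from a model
target = (torch.rand(4, 2, 3) > 0.5).float()     # binary targets
pixel_possibility_tensor = torch.rand(4, 2, 3)   # made-up "possibilities"

weight = (pixel_possibility_tensor > 0.35).float()
loss = (weight.numel() / weight.sum()) * torch.nn.BCEWithLogitsLoss(
    weight=weight)(input, target)

# average the unreduced loss over only the non-masked pixels
per_pixel = torch.nn.functional.binary_cross_entropy_with_logits(
    input, target, reduction='none')
masked_mean = per_pixel[weight.bool()].mean()

print(torch.allclose(loss, masked_mean))   # True
```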

Best.

K. Frank

Got it. Thank you, Frank.