Hi there. Is there a way for me to calculate the BCE loss for different areas of a batch with different weights? It seemed to me that the weight argument (“a manual rescaling weight if provided it’s repeated to match input tensor shape”) described in the PyTorch 1.9.1 documentation for torch.nn.functional.binary_cross_entropy_with_logits and BCELoss only works at a batch level, not pixel by pixel. Can I apply the weight to specific pixels? Thanks.
Hi James!
The documentation is misleadingly unclear, but BCEWithLogitsLoss (which you should use) and BCELoss (which you shouldn’t use) do support what you call “pixel-by-pixel” weights (as do their functional equivalents).
Let’s assume that input and target have the same shape (which they will in any reasonable use case). Then what is required is that the weight argument passed into BCEWithLogitsLoss’s constructor have a shape that is broadcastable to the shape of input (and hence to the shape of target). After broadcasting, weight is just used to scale each term in the loss function on an element-by-element basis.
(Those elements can be pixels in an image, classes in a multi-label problem, channels, etc., or any combination thereof. BCEWithLogitsLoss doesn’t care how you choose to interpret them – they’re just elements of a multi-dimensional tensor.)
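For example (a minimal sketch – the shapes and weight values here are made up for illustration), a weight of shape [nBatch, 1, 1] broadcasts across each sample’s pixels, giving you per-sample weighting:

import torch

nBatch, height, width = 4, 8, 8
input = torch.randn(nBatch, height, width)   # logits, e.g. straight from your model
target = torch.rand(nBatch, height, width)   # targets in [0.0, 1.0]

# shape [nBatch, 1, 1] -- broadcast so that each sample gets its own weight
sample_weight = torch.tensor([1.0, 0.5, 2.0, 1.0]).reshape(nBatch, 1, 1)

loss = torch.nn.BCEWithLogitsLoss(weight=sample_weight)(input, target)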
If weight has the same shape as input, then no broadcasting is done (or you could say that the broadcasting is trivial), and weight will be applied to “specific pixels” in the way that I think you’re asking for.
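Here’s a sketch of that same-shape case (again with made-up shapes), including a check against the unreduced (reduction = 'none') loss showing that each pixel’s term really is scaled by its own weight:

import torch

nBatch, height, width = 4, 8, 8
input = torch.randn(nBatch, height, width)   # logits
target = torch.rand(nBatch, height, width)
weight = torch.rand(nBatch, height, width)   # one weight per individual pixel

loss = torch.nn.BCEWithLogitsLoss(weight=weight)(input, target)

# the weighted loss is the mean of the per-pixel losses, each scaled by its weight
per_pixel = torch.nn.BCEWithLogitsLoss(reduction='none')(input, target)
print(torch.allclose(loss, (weight * per_pixel).mean()))   # True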
Best.
K. Frank
Many thanks for your detailed explanation, Frank. Is there a way to explicitly mask out parts of the loss with PyTorch’s BCE loss? I would like to mask out the low-possibility pixels with weight = 0 and keep the others.
Hi James!
If I understand your question correctly, you can just use weights of 0.0 and 1.0 to do this.
Let’s say that the input to your model is a batch of images of shape [nBatch, height, width], and that the output of your model (which will be the input to BCEWithLogitsLoss) and the target you pass to BCEWithLogitsLoss also have this shape.
Now let me assume that you somehow construct a pixel_possibility_tensor, also of the same shape, and that this tensor tells you the “possibility” of each pixel (whatever that means in your use case). If a “possibility” of, say, less than 0.35 is considered “low possibility”, you would:
weight = (pixel_possibility_tensor > 0.35).float()   # 1.0 keeps a pixel, 0.0 masks it
loss = (weight.numel() / weight.sum()) * torch.nn.BCEWithLogitsLoss(weight=weight)(input, target)
In the loss computation, the 0.0 values in weight will mask out the “low possibility” pixels (and the 1.0 values will keep the remaining pixels unchanged). The weight.numel() / weight.sum() pre-factor is so that loss will be the loss averaged only over the non-masked pixels, assuming that’s what you want – without it, all of the masked pixels would cause a bunch of zeros to get averaged into the loss.
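Here’s a self-contained version of the above as a sketch (the tensors and the 0.35 threshold are placeholders), with a check that the pre-factor reproduces the average over just the kept pixels:

import torch

nBatch, height, width = 4, 8, 8
input = torch.randn(nBatch, height, width)                 # logits from your model
target = torch.rand(nBatch, height, width)
pixel_possibility_tensor = torch.rand(nBatch, height, width)

weight = (pixel_possibility_tensor > 0.35).float()         # 1.0 = keep, 0.0 = mask
loss = (weight.numel() / weight.sum()) * torch.nn.BCEWithLogitsLoss(weight=weight)(input, target)

# average the unreduced loss over only the non-masked pixels
per_pixel = torch.nn.BCEWithLogitsLoss(reduction='none')(input, target)
masked_mean = per_pixel[weight.bool()].mean()
print(torch.allclose(loss, masked_mean))                   # True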
Best.
K. Frank
Got it. Thank you, Frank.