"pos_weight" and "weight" parameters in BCEWithLogitsLoss

Hi Ali!

Let me answer a couple of your specific questions first and then explain
how I look at it.

In this case the current (1.9.0) documentation for BCEWithLogitsLoss
is wrong (or at least misleading). Quoting:

weight (Tensor, optional) – a manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size nbatch.

On the contrary, weight can have other shapes – for example, it can
have the same shape as input. (As near as I can tell, the precise
requirement is that the shape of weight and input need to be
broadcastable.)
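Here is a quick check of that claim. It is a sketch against a recent pytorch, and the shapes and seed are arbitrary choices of mine. A weight with the same shape as input is accepted, and the result matches weighting the unreduced element-wise losses by hand and then taking the mean:

```python
import torch

torch.manual_seed(0)
nBatch, nClass = 2, 3
input = torch.randn(nBatch, nClass)
target = torch.rand(nBatch, nClass)   # probabilistic targets in [0, 1]
weight = torch.rand(nBatch, nClass)   # same shape as input, not shape [nbatch]

# Weighted loss computed by BCEWithLogitsLoss itself.
weighted = torch.nn.BCEWithLogitsLoss(weight=weight)(input, target)

# The same thing by hand: unweighted element-wise losses, times weight, then the mean.
elementwise = torch.nn.BCEWithLogitsLoss(reduction='none')(input, target)
manual = (weight * elementwise).mean()

print(torch.allclose(weighted, manual))  # True
```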

The documentation for the functional version,
binary_cross_entropy_with_logits(), comes closer to being correct:

weight (Tensor, optional) – a manual rescaling weight if provided it’s repeated to match input tensor shape

I think that what you are calling the “class weights tensor” is the
weight tensor. (Any imagined difference would arise from how
you interpret them based on your use case.)

Please note that BCEWithLogitsLoss takes four tensor arguments:
weight, pos_weight, input, and target. The first two are passed
in when BCEWithLogitsLoss’s constructor is called to instantiate a
loss-function object, and the second two are passed in when the
resulting loss-function object is called. There is no separate “class
weights” argument.
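To make that split concrete, here is a minimal sketch (arbitrary shapes and seed). weight and pos_weight go to the constructor, input and target go to the call, and the functional binary_cross_entropy_with_logits() takes all four at once and gives the same result:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.rand(3)
pos_weight = torch.rand(3)
input = torch.randn(4, 3)
target = torch.rand(4, 3)

# weight and pos_weight are passed to the constructor ...
loss_fn = torch.nn.BCEWithLogitsLoss(weight=weight, pos_weight=pos_weight)
# ... while input and target are passed when the loss-function object is called.
module_loss = loss_fn(input, target)

# The functional form takes all four tensors in a single call.
functional_loss = F.binary_cross_entropy_with_logits(
    input, target, weight=weight, pos_weight=pos_weight
)

print(torch.allclose(module_loss, functional_loss))  # True
```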

Now to explain my understanding:

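As far as I can tell, weight and pos_weight need only be broadcastable
to input and target. (input and target themselves must have the same
shape; binary_cross_entropy_with_logits() raises a ValueError otherwise.)
To simplify the discussion, let’s assume that all four are of the same shape.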

BCEWithLogitsLoss doesn’t make any distinction, for example,
between labels / predictions for a specific class and specific samples
within a batch. It simply applies the BCEWithLogitsLoss formula
on an element-wise basis, including the weight and pos_weight
weightings, also on an element-wise basis. This produces a tensor
of element-wise loss values of the same shape as input (and the
other arguments) that is then reduced (or not) according to the value
of reduction.
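Written out by hand, the element-wise computation is loss = -weight * (pos_weight * target * log(sigmoid(input)) + (1 - target) * log(1 - sigmoid(input))), followed by the reduction. The helper manual_bce_with_logits() below is my own hypothetical sketch of that formula, not pytorch's internal code (which organizes the arithmetic differently for numerical stability). Note that nothing in it refers to any particular dimension:

```python
import torch
import torch.nn.functional as F

def manual_bce_with_logits(input, target, weight=None, pos_weight=None, reduction='mean'):
    """Hypothetical re-implementation of the element-wise BCEWithLogitsLoss formula."""
    # log(sigmoid(x)) and log(1 - sigmoid(x)), computed stably;
    # 1 - sigmoid(x) == sigmoid(-x).
    log_sig = F.logsigmoid(input)
    log_one_minus_sig = F.logsigmoid(-input)
    # pos_weight multiplies only the "positive" term of each element's loss.
    pos_term = target * log_sig
    if pos_weight is not None:
        pos_term = pos_weight * pos_term
    loss = -(pos_term + (1 - target) * log_one_minus_sig)
    # weight multiplies the whole element-wise loss.
    if weight is not None:
        loss = weight * loss
    # Only now is the (dimension-agnostic) reduction applied.
    if reduction == 'none':
        return loss
    if reduction == 'sum':
        return loss.sum()
    return loss.mean()

# Compare against pytorch in double precision so float32 rounding doesn't matter.
torch.manual_seed(2021)
shape = (2, 3, 5)
weight = torch.rand(shape, dtype=torch.float64)
pos_weight = torch.rand(shape, dtype=torch.float64)
input = torch.randn(shape, dtype=torch.float64)
target = torch.rand(shape, dtype=torch.float64)

for reduction in ('none', 'mean', 'sum'):
    expected = torch.nn.BCEWithLogitsLoss(weight=weight, pos_weight=pos_weight,
                                          reduction=reduction)(input, target)
    actual = manual_bce_with_logits(input, target, weight, pos_weight, reduction)
    print(reduction, torch.allclose(actual, expected))  # True for each reduction
```

With reduction = 'none' the result has the same shape as input, which is the tensor of element-wise losses described above.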

Consider the following:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> nBatch = 2
>>> nClass = 3
>>> nSomethingElse = 5
>>> weight = torch.rand (nBatch, nClass, nSomethingElse)
>>> pos_weight = torch.rand (nBatch, nClass, nSomethingElse)
>>> input = torch.randn (nBatch, nClass, nSomethingElse)
>>> target = torch.rand (nBatch, nClass, nSomethingElse)
>>> torch.nn.BCEWithLogitsLoss (weight = weight, pos_weight = pos_weight) (input, target)
tensor(0.3176)
>>> torch.nn.BCEWithLogitsLoss (weight = weight.flatten(), pos_weight = pos_weight.flatten()) (input.flatten(), target.flatten())
tensor(0.3176)

BCEWithLogitsLoss doesn’t care about any particular dimensions or
assign them meanings like “batch” or “class” or “height” or “width” – it
simply performs the (weighted) element-wise loss computation and
then reduces the result.

Where does the notion of a “class weights tensor” come from?
Consider a use case where we are performing a multi-label,
nClass-class loss calculation for a batch of nBatch samples:

Let input and target both have shape [nBatch, nClass]. If weight
has shape [nClass] it will be broadcast to match the shape of input,
and, indeed, the elements of weight will be class weights in the loss
calculation – but not because BCEWithLogitsLoss knows or cares
about what you might interpret as a “class” dimension. Rather, the
elements of weight become class weights just because that’s how
the tensor elements line up in the element-wise computation after
broadcasting.
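A quick sketch of that (arbitrary sizes and seed): a weight of shape [nClass] gives the same loss as the same weights explicitly expanded to the full [nBatch, nClass] shape, i.e., every row of the element-wise loss is scaled by the same per-class values:

```python
import torch

torch.manual_seed(0)
nBatch, nClass = 4, 3
input = torch.randn(nBatch, nClass)
target = torch.rand(nBatch, nClass)
class_weights = torch.rand(nClass)

# Shape [nClass] broadcasts against [nBatch, nClass]: each row gets the same weights.
broadcast_loss = torch.nn.BCEWithLogitsLoss(weight=class_weights)(input, target)

# Expanding the weights explicitly to the full shape gives the same result.
expanded = class_weights.expand(nBatch, nClass)
expanded_loss = torch.nn.BCEWithLogitsLoss(weight=expanded)(input, target)

print(torch.allclose(broadcast_loss, expanded_loss))  # True
```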

Note that if you modify this example to try to pass in a 1d tensor of sample
weights (of shape [nBatch]), as suggested by the documentation quoted
above, it won’t work.

Thus:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> nBatch = 2
>>> nClass = 3
>>> sample_weights = torch.rand (nBatch)
>>> class_weights = torch.rand (nClass)
>>> input = torch.randn (nBatch, nClass)
>>> target = torch.rand (nBatch, nClass)
>>> torch.nn.BCEWithLogitsLoss (weight = class_weights) (input, target)
tensor(0.5677)
>>> torch.nn.BCEWithLogitsLoss (weight = sample_weights) (input, target)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path_to_pytorch>\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "<path_to_pytorch>\torch\nn\modules\loss.py", line 716, in forward
    reduction=self.reduction)
  File "<path_to_pytorch>\torch\nn\functional.py", line 2960, in binary_cross_entropy_with_logits
    return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1

If you want sample weights (rather than class weights) in this use
case, you would have to unsqueeze() sample_weights so that
broadcasting lines the weights up the way you want:

>>> torch.nn.BCEWithLogitsLoss (weight = sample_weights.unsqueeze (1)) (input, target)
tensor(0.3120)
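To see what the unsqueeze() buys you, the following sketch (arbitrary sizes and seed) checks with reduction = 'none' that each row i of the weighted element-wise loss is just row i of the unweighted loss multiplied by sample_weights[i]:

```python
import torch

torch.manual_seed(0)
nBatch, nClass = 2, 3
input = torch.randn(nBatch, nClass)
target = torch.rand(nBatch, nClass)
sample_weights = torch.rand(nBatch)

# unsqueeze(1) turns shape [nBatch] into [nBatch, 1], which broadcasts along
# the class dimension: every element in row i is scaled by sample_weights[i].
w = sample_weights.unsqueeze(1)
print(w.shape)  # torch.Size([2, 1])

weighted = torch.nn.BCEWithLogitsLoss(weight=w, reduction='none')(input, target)
unweighted = torch.nn.BCEWithLogitsLoss(reduction='none')(input, target)

# Row i of the weighted loss is row i of the unweighted loss times sample_weights[i].
for i in range(nBatch):
    print(i, torch.allclose(weighted[i], sample_weights[i] * unweighted[i]))  # True
```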

One last clarifying comment:

Although we often use BCEWithLogitsLoss with target values
(ground-truth labels) that are binary no-yes labels (expressed as the
floating-point numbers 0.0 and 1.0), BCEWithLogitsLoss is more
general and accepts a probabilistic target whose elements are
floating-point values that run from 0.0 to 1.0 and represent the
probability that the sample in question is in class-“1”.

In this more general case we don’t have samples that are purely
“negative” or “positive,” so, strictly speaking, pos_weight doesn’t
weight the “positive” samples. Rather, it weights the “positive” part of
the binary-cross-entropy formula used for each individual element-wise
loss computation.
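One way to see this: for a fixed input, the element-wise loss is linear in target. A probabilistic target y mixes the target = 1 loss and the target = 0 loss with weights y and 1 - y, and pos_weight scales only the target = 1 part. A small sketch checking both points (the pos_weight value of 3.0, the seed, and the sizes are arbitrary choices of mine; double precision just keeps the comparisons tight):

```python
import torch

torch.manual_seed(0)
input = torch.randn(5, dtype=torch.float64)
target = torch.rand(5, dtype=torch.float64)   # probabilistic targets in [0, 1]
pos_weight = torch.full((5,), 3.0, dtype=torch.float64)

def per_element_loss(t):
    # Unreduced, pos_weight-ed loss for a given target tensor t.
    return torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='none')(input, t)

loss_soft = per_element_loss(target)
loss_pos = per_element_loss(torch.ones_like(target))
loss_neg = per_element_loss(torch.zeros_like(target))

# A soft target mixes the "positive" and "negative" losses with weights y and 1 - y.
print(torch.allclose(loss_soft, target * loss_pos + (1 - target) * loss_neg))  # True

# pos_weight only rescales the positive part: with target 0 it has no effect ...
no_pw = torch.nn.BCEWithLogitsLoss(reduction='none')
print(torch.allclose(loss_neg, no_pw(input, torch.zeros_like(target))))  # True
# ... and with target 1 it simply multiplies the loss.
print(torch.allclose(loss_pos, 3.0 * no_pw(input, torch.ones_like(target))))  # True
```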

An aside about terminology: This is not “one-hot” encoding (and, as a
rule of thumb, there’s never really any reason to use one-hot encoding
with pytorch). You have a multi-label use case and your sample labels
are “multi-hot encoded,” if you will.

The term “one-hot encoding” is often used imprecisely, but doing so
can be quite misleading. A one-hot encoded single-label class label
is a vector where exactly one element is 1 and all the others are 0.
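For illustration only, compare a one-hot encoded single-label target with a multi-hot multi-label target. I use torch.nn.functional.one_hot() here just to show what a one-hot label looks like, not because you need it:

```python
import torch
import torch.nn.functional as F

# Single-label: one integer class index per sample.
class_indices = torch.tensor([2, 0, 1])
one_hot = F.one_hot(class_indices, num_classes=3)
print(one_hot)
# tensor([[0, 0, 1],
#         [1, 0, 0],
#         [0, 1, 0]])
# One-hot: every row has exactly one 1.
print((one_hot.sum(dim=1) == 1).all())  # tensor(True)

# Multi-label: a sample may belong to any number of classes (including none),
# so a row can have zero, one, or several 1s ("multi-hot").
multi_hot = torch.tensor([[1., 0., 1.],
                          [0., 0., 0.],
                          [1., 1., 1.]])
print(multi_hot.sum(dim=1))  # tensor([2., 0., 3.])

# A multi-hot target is a floating-point tensor shaped like input.
input = torch.randn(3, 3)
loss = torch.nn.BCEWithLogitsLoss()(input, multi_hot)
```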

Best.

K. Frank
