Hi Ali!
Let me answer a couple of your specific questions first and then explain
how I look at it.
In this case the current (1.9.0) documentation for BCEWithLogitsLoss is wrong (or at least misleading). Quoting:
weight (Tensor, optional) – a manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size nbatch.
On the contrary, weight can have other shapes – for example, it can have the same shape as input. (As near as I can tell, the precise requirement is that the shapes of weight and input need to be broadcastable.)
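Here is a quick check of that broadcasting claim – a sketch, with the tensor names and sizes just made up for illustration:

import torch

nBatch, nClass, nSomethingElse = 2, 3, 5
input = torch.randn (nBatch, nClass, nSomethingElse)
target = torch.rand (nBatch, nClass, nSomethingElse)
weight = torch.rand (1, nClass, 1)   # broadcastable to input's shape, but not equal to it

# weight is accepted and behaves exactly as if it had been expanded
# to input's full shape by hand
lossA = torch.nn.BCEWithLogitsLoss (weight = weight) (input, target)
lossB = torch.nn.BCEWithLogitsLoss (weight = weight.expand (nBatch, nClass, nSomethingElse)) (input, target)
print (torch.allclose (lossA, lossB))   # should print True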
The documentation for the functional version,
binary_cross_entropy_with_logits(), comes closer to being correct:
weight (Tensor, optional) – a manual rescaling weight if provided it’s repeated to match input tensor shape
I think that what you are calling the “class weights tensor” is the weight tensor. (Any imagined difference would arise from how you interpret them based on your use case.)
Please note that BCEWithLogitsLoss takes four tensor arguments: weight, pos_weight, input, and target. The first two are passed in when BCEWithLogitsLoss’s constructor is called to instantiate a loss-function object, and the second two are passed in when the resulting loss-function object is called. There is no separate “class weights” argument.
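For concreteness, here is a minimal sketch of the two call sites (the shapes are just placeholders):

import torch

# weight and pos_weight go to the constructor ...
loss_fn = torch.nn.BCEWithLogitsLoss (weight = torch.ones (3), pos_weight = torch.ones (3))
# ... while input and target go to the resulting loss-function object
loss = loss_fn (torch.randn (2, 3), torch.rand (2, 3))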
Now to explain my understanding:
As far as I can tell, weight, pos_weight, input, and target need only be broadcastable to one another. To simplify the discussion, let’s assume that they are all of the same shape.
BCEWithLogitsLoss doesn’t make any distinction, for example, between labels / predictions for a specific class and specific samples within a batch. It simply applies the BCEWithLogitsLoss formula on an element-wise basis, including the weight and pos_weight weightings, also on an element-wise basis. This produces a tensor of element-wise loss values of the same shape as input (and the other arguments) that is then reduced (or not) according to the value of reduction.
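If it helps, here is my understanding of that element-wise formula, checked against the built-in version (a sketch, with all shapes assumed equal; logsigmoid() is used for numerical stability):

import torch

torch.manual_seed (2021)
input = torch.randn (2, 3)
target = torch.rand (2, 3)
weight = torch.rand (2, 3)
pos_weight = torch.rand (2, 3)

# per element: -weight * (pos_weight * target * log (sigmoid (input))
#                         + (1 - target) * log (1 - sigmoid (input)))
log_p = torch.nn.functional.logsigmoid (input)      # log (sigmoid (input))
log_1mp = torch.nn.functional.logsigmoid (-input)   # log (1 - sigmoid (input))
manual = -weight * (pos_weight * target * log_p + (1.0 - target) * log_1mp)

builtin = torch.nn.BCEWithLogitsLoss (weight = weight, pos_weight = pos_weight, reduction = 'none') (input, target)
print (torch.allclose (manual, builtin))   # should print True

# the default 'mean' reduction is a plain mean over all elements
default = torch.nn.BCEWithLogitsLoss (weight = weight, pos_weight = pos_weight) (input, target)
print (torch.allclose (manual.mean(), default))   # should print True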
Consider the following:
>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> nBatch = 2
>>> nClass = 3
>>> nSomethingElse = 5
>>> weight = torch.rand (nBatch, nClass, nSomethingElse)
>>> pos_weight = torch.rand (nBatch, nClass, nSomethingElse)
>>> input = torch.randn (nBatch, nClass, nSomethingElse)
>>> target = torch.rand (nBatch, nClass, nSomethingElse)
>>> torch.nn.BCEWithLogitsLoss (weight = weight, pos_weight = pos_weight) (input, target)
tensor(0.3176)
>>> torch.nn.BCEWithLogitsLoss (weight = weight.flatten(), pos_weight = pos_weight.flatten()) (input.flatten(), target.flatten())
tensor(0.3176)
BCEWithLogitsLoss doesn’t care about any particular dimensions or assign them meanings like “batch” or “class” or “height” or “width” – it simply performs the (weighted) element-wise loss computation and then reduces the result.
Where does the notion of a “class weights tensor” come from?
Consider a use case where we are performing a multi-label, nClass-class loss calculation for a batch of nBatch samples:
Let input and target both have shape [nBatch, nClass]. If weight has shape [nClass] it will be broadcast to match the shape of input, and, indeed, the elements of weight will be class weights in the loss calculation – but not because BCEWithLogitsLoss knows or cares about what you might interpret as a “class” dimension. Rather, the elements of weight become class weights just because that’s how the tensor elements line up in the element-wise computation after broadcasting.
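A quick sketch of this point: the [nClass] weight and its explicitly batch-replicated version give the same loss:

import torch

nBatch, nClass = 2, 3
input = torch.randn (nBatch, nClass)
target = torch.rand (nBatch, nClass)
class_weights = torch.rand (nClass)

lossA = torch.nn.BCEWithLogitsLoss (weight = class_weights) (input, target)
# the same weights, explicitly replicated for every sample in the batch
lossB = torch.nn.BCEWithLogitsLoss (weight = class_weights.expand (nBatch, nClass)) (input, target)
print (torch.allclose (lossA, lossB))   # should print True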
Note, if you modify this example to try to pass in a 1d tensor of sample weights (of shape [nBatch]), as suggested by the documentation quoted above, it won’t work.
Thus:
>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> nBatch = 2
>>> nClass = 3
>>> sample_weights = torch.rand (nBatch)
>>> class_weights = torch.rand (nClass)
>>> input = torch.randn (nBatch, nClass)
>>> target = torch.rand (nBatch, nClass)
>>> torch.nn.BCEWithLogitsLoss (weight = class_weights) (input, target)
tensor(0.5677)
>>> torch.nn.BCEWithLogitsLoss (weight = sample_weights) (input, target)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<path_to_pytorch>\torch\nn\modules\module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "<path_to_pytorch>\torch\nn\modules\loss.py", line 716, in forward
reduction=self.reduction)
File "<path_to_pytorch>\torch\nn\functional.py", line 2960, in binary_cross_entropy_with_logits
return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1
If you want sample weights (rather than class weights) in this use case, you would have to unsqueeze() sample_weights so that broadcasting lines the weights up the way you want:
>>> torch.nn.BCEWithLogitsLoss (weight = sample_weights.unsqueeze (1)) (input, target)
tensor(0.3120)
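As a sanity check (a sketch, with fresh random tensors), the unsqueezed sample weights just scale every per-element loss in a given row before the 'mean' reduction:

import torch

nBatch, nClass = 2, 3
sample_weights = torch.rand (nBatch)
input = torch.randn (nBatch, nClass)
target = torch.rand (nBatch, nClass)

unweighted = torch.nn.BCEWithLogitsLoss (reduction = 'none') (input, target)
manual = (sample_weights.unsqueeze (1) * unweighted).mean()
builtin = torch.nn.BCEWithLogitsLoss (weight = sample_weights.unsqueeze (1)) (input, target)
print (torch.allclose (manual, builtin))   # should print True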
One last clarifying comment:
Although we often use BCEWithLogitsLoss with target values (ground-truth labels) that are binary no-yes labels (expressed as the floating-point numbers 0.0 and 1.0), BCEWithLogitsLoss is more general and accepts a probabilistic target whose elements are floating-point values that run from 0.0 to 1.0 and represent the probability that the sample in question is in class-“1”.
In this more general case we don’t have samples that are purely “negative” or “positive,” so, strictly speaking, pos_weight doesn’t weight the “positive” samples. Rather, it weights the “positive” part of the binary-cross-entropy formula used for each individual element-wise loss computation.
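A one-element sketch of this point (the numbers are made up): even with a probabilistic target of 0.7, pos_weight multiplies only the “positive” log (sigmoid()) term of the formula:

import torch

x = torch.tensor ([0.5])    # a single logit
y = torch.tensor ([0.7])    # probabilistic target
pw = torch.tensor ([2.0])   # pos_weight

manual = -(pw * y * torch.nn.functional.logsigmoid (x) + (1.0 - y) * torch.nn.functional.logsigmoid (-x))
builtin = torch.nn.BCEWithLogitsLoss (pos_weight = pw, reduction = 'none') (x, y)
print (torch.allclose (manual, builtin))   # should print True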
An aside about terminology: This is not “one-hot” encoding (and, as a
rule of thumb, there’s never really any reason to use one-hot encoding
with pytorch). You have a multi-label use case and your sample labels
are “multi-hot encoded,” if you will.
The term “one-hot encoding” is often used imprecisely, but doing so can be quite misleading. A one-hot encoded single-label class label is a vector where exactly one element is 1 and all the others are 0.
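For concreteness (made-up labels, with nClass = 4):

import torch

one_hot = torch.tensor ([0., 0., 1., 0.])    # single-label: exactly one element is 1
multi_hot = torch.tensor ([1., 0., 1., 0.])  # multi-label: any number of elements may be 1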
Best.
K. Frank