The short story is that pos_weight has to broadcast with respect to output
and target, and in your case it doesn’t: broadcasting aligns trailing
dimensions, so a pos_weight of shape [64] lines up against the last
(size-20) dimension of a [10, 64, 20, 20] tensor rather than the class
dimension.
Consider:
>>> import torch
>>> torch.__version__
'1.11.0'
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> target = torch.ones([10, 64, 20, 20], dtype=torch.float32) # 64 classes, batch size = 10
>>> output = torch.full([10, 64, 20, 20], 1.5) # A prediction (logit)
>>>
>>> torch.nn.BCEWithLogitsLoss() (output, target)
tensor(0.2014)
>>> torch.nn.BCEWithLogitsLoss (pos_weight = torch.ones ([64, 1, 1])) (output, target)
tensor(0.2014)
>>> torch.nn.BCEWithLogitsLoss (pos_weight = torch.ones ([64, 20, 1])) (output, target)
tensor(0.2014)
>>> torch.nn.BCEWithLogitsLoss (pos_weight = torch.ones ([64, 20, 20])) (output, target)
tensor(0.2014)
>>> torch.nn.BCEWithLogitsLoss (pos_weight = torch.tensor ([1.])) (output, target)
tensor(0.2014)
>>> torch.nn.BCEWithLogitsLoss (pos_weight = torch.ones ([64])) (output, target)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<path_to_pytorch_install>\torch\nn\modules\module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "<path_to_pytorch_install>\torch\nn\modules\loss.py", line 713, in forward
return F.binary_cross_entropy_with_logits(input, target,
File "<path_to_pytorch_install>\torch\nn\functional.py", line 3132, in binary_cross_entropy_with_logits
return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
RuntimeError: The size of tensor a (64) must match the size of tensor b (20) at non-singleton dimension 3
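If what you want is one weight per class, the usual fix is to reshape the [64] pos_weight so that its class entry lines up with dim 1 of an [N, C, H, W] tensor. A minimal sketch (same shapes and logit value as above):

```python
import torch

target = torch.ones([10, 64, 20, 20])          # 64 classes, batch size = 10
output = torch.full([10, 64, 20, 20], 1.5)     # a prediction (logit)

# pos_weight of shape [64] fails because broadcasting aligns trailing
# dimensions; viewing it as [64, 1, 1] lines it up with the class dimension.
pos_weight = torch.ones(64)
loss = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight.view(64, 1, 1))(output, target)
print(loss)   # matches the unweighted loss, since all weights are 1.0
```

(`pos_weight.unsqueeze(-1).unsqueeze(-1)` does the same thing.)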
*) The BCEWithLogitsLoss documentation is quite elliptical on this point,
but, as near as I can tell (from experiments rather than the documentation), BCEWithLogitsLoss doesn’t actually have any notion of “batch size”
or “number of classes,” nor does it assign any particular meaning to any
of the dimensions. Rather, output, target, weight, and pos_weight all
have to broadcast with respect to one another; the loss is then computed
element-wise and reduced.
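You can check this element-wise picture directly. The sketch below (with arbitrary shapes chosen for illustration) applies the documented formula by hand — pos_weight scales only the positive (target = 1) term, broadcasting handles the shapes, and the mean does the reduction — and compares it against BCEWithLogitsLoss:

```python
import torch

torch.manual_seed(2022)
output = torch.randn(10, 64, 20, 20)                     # logits
target = torch.randint(2, (10, 64, 20, 20)).float()     # 0 / 1 labels
pos_weight = torch.rand(64, 1, 1) + 0.5                  # per-class weights

loss_builtin = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)

# Manual element-wise version: pos_weight multiplies only the positive term,
# the tensors broadcast against one another, and the mean reduces the result.
log_sig = torch.nn.functional.logsigmoid
loss_manual = -(pos_weight * target * log_sig(output)
                + (1 - target) * log_sig(-output)).mean()

assert torch.allclose(loss_builtin, loss_manual)
```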