BCEWithLogits with weight/pos_weight

Hello, I have read several topics about adding weights to a loss, but I still have a question they don't seem to answer.

I have a binary segmentation problem with two classes, 0 (background) and 1 (buildings), and the classes are unbalanced. I decided to weight the BCEWithLogitsLoss with torch.tensor([0.3, 0.7]) for classes 0 and 1, respectively. But when I try to calculate the loss, the script throws a RuntimeError:

import torch
output = torch.randn(1, 2, 256, 256, requires_grad=True)
target = torch.randn(1, 2, 256, 256, requires_grad=False)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([0.3, 0.7])) 
# same error when I use 'weight' instead of pos_weight
loss = criterion(output, target)

Error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-0ea142ac673d> in <module>
      3 target = torch.randn(1, 2, 256, 256, requires_grad=False)
      4 criterion = torch.nn.BCEWithLogitsLoss(weight=torch.tensor([0.3, 0.7]))
----> 5 loss = criterion(output, target)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    615                                                   self.weight,
    616                                                   pos_weight=self.pos_weight,
--> 617                                                   reduction=self.reduction)
    618 
    619 

/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
   2433         raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
   2434 
-> 2435     return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
   2436 
   2437 

RuntimeError: The size of tensor a (256) must match the size of tensor b (2) at non-singleton dimension 3

But when I pass a pos_weight tensor with a single value (shape [1]), it works:

import torch
output = torch.randn(1, 2, 256, 256, requires_grad=True)
target = torch.randn(1, 2, 256, 256, requires_grad=False)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([0.3]))
loss = criterion(output, target)

print(loss)

>>> tensor(0.2424, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

My question is why the first approach fails while the second one works. And if the second approach is the correct one, which class does this single weight apply to, so that I get a better loss?

Thank you!

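A binary segmentation use case with nn.BCEWithLogitsLoss would expect the model output and the target to have the shape [batch_size, 1, height, width], where a low logit would indicate class0 and a high logit would indicate class1.
For this use case pos_weight should also be a single value, as described in the docs. It only scales the loss term of the positive targets (class1), so a value above 1 gives the buildings more importance; the docs suggest using the ratio of negative to positive samples (e.g. 300 negatives and 100 positives give pos_weight=3).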
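To make the "single value" setup concrete, here is a minimal sketch on random data. The tensor names, the batch shape and the ~20% building ratio are illustrative assumptions, not taken from your code. It computes pos_weight as #negatives / #positives from the target mask and then checks the built-in loss against a manual version, which shows that pos_weight only multiplies the class1 term:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative binary segmentation batch: one output channel, targets in {0, 1}
logits = torch.randn(4, 1, 64, 64, requires_grad=True)
target = (torch.rand(4, 1, 64, 64) < 0.2).float()  # roughly 20% "building" pixels

# Heuristic from the docs: pos_weight = #negatives / #positives
num_pos = target.sum()
num_neg = target.numel() - num_pos
pos_weight = (num_neg / num_pos).reshape(1)  # shape [1], broadcast to every pixel

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(logits, target)

# Manual loss: pos_weight scales only the positive (class1) term;
# log(1 - sigmoid(x)) is written as logsigmoid(-x) for numerical stability
manual = -(pos_weight * target * F.logsigmoid(logits)
           + (1 - target) * F.logsigmoid(-logits)).mean()

print(loss.item(), manual.item())  # the two values should agree up to float error
```

The background term has no weight of its own here; class0 is implicitly weighted by 1, so only the ratio between the two classes matters.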

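The error message itself is a bit surprising, but it comes from broadcasting: a pos_weight of shape [2] is aligned with the last dimension of the output (the width, 256) instead of the channel dimension, hence "at non-singleton dimension 3". For a multi-label segmentation, where each pixel can have 0, 1, or multiple active classes, you would need to unsqueeze pos_weight in dim2 and dim3 so that each value lines up with one channel: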

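import torch

output = torch.randn(1, 2, 256, 256, requires_grad=True)
# targets should be in [0, 1]; here random multi-label targets in {0, 1}
target = torch.randint(0, 2, (1, 2, 256, 256)).float()
pos_weight = torch.tensor([0.3, 0.7])
# reshape to [1, 2, 1, 1] so each weight is applied to one channel
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight[None, :, None, None])
loss = criterion(output, target)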
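If you want to double-check the broadcasting explanation, a small sketch like the following can help (it assumes a PyTorch version that provides torch.broadcast_shapes, i.e. 1.8 or newer). It shows that a [2] tensor cannot be broadcast against [1, 2, 256, 256], that the [1, 2, 1, 1] view can, and that the result matches averaging two per-channel losses, each with its own scalar pos_weight. The traceback in the question was produced with weight, which is broadcast against the per-element loss in the same way.

```python
import torch

torch.manual_seed(0)

shape = (1, 2, 256, 256)
pos_weight = torch.tensor([0.3, 0.7])

# Broadcasting aligns trailing dimensions: [2] is matched against the width (256)
try:
    torch.broadcast_shapes(shape, pos_weight.shape)
except RuntimeError as e:
    print("broadcast failed:", e)

# Reshaped to [1, 2, 1, 1], each weight lines up with one channel
print(torch.broadcast_shapes(shape, pos_weight[None, :, None, None].shape))

# Sanity check: with equally sized channels, the mean loss equals the average
# of per-channel losses, each using its own scalar pos_weight
output = torch.randn(shape)
target = torch.randint(0, 2, shape).float()  # multi-label targets in {0, 1}
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight[None, :, None, None])
loss = criterion(output, target)

per_channel = [
    torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight[c:c + 1])(output[:, c], target[:, c])
    for c in range(shape[1])
]
print(loss.item(), torch.stack(per_channel).mean().item())
```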