BCELoss with class weights

I’m doing an image segmentation task. It’s a binary case. That is, the target pixels are either 0 (not of the class) or 1 (belong to the class).

I’m using BCELoss as the loss function. I’m using BCE instead of BCEWithLogits because my model already has a sigmoid at the end.

My dataset is quite unbalanced. Of all the pixels, only a small percentage belong to the target class. Hence, I’d like to use class weights.

I’ve tried to do this:

class_weights = torch.tensor([0.5, 74.0])
criterion = nn.BCELoss(weight=class_weights)

loss = criterion(predictions, label)

but I get an error:
RuntimeError: The size of tensor a (352) must match the size of tensor b (2) at non-singleton dimension 3.

How can I do this? From what I’ve gathered, weight is expected to be broadcastable to the shape of the input, so my 2-element tensor gets matched against the last dimension (352) and fails. It rescales the loss of each element in the batch, which is sample weighting, but what I actually want is class weighting. From what I understand, BCEWithLogitsLoss has both weight and pos_weight, but I would like to avoid changing the criterion.

I guess there might be workarounds using nn.BCELoss with sample weighting, computing the per-pixel weights for each batch from the class distribution, but I would rather switch to nn.BCEWithLogitsLoss, which lets you use pos_weight directly and is more numerically stable.
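A minimal sketch of that BCELoss workaround, assuming the model output already went through a sigmoid and the target is a float mask of the same shape. The per-pixel weight tensor is built from the target, so it has the input's shape instead of the 2-element shape that caused the error. The shapes and random data are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical shapes: batch of 2 single-channel 4x6 masks (real data is larger).
probs = torch.rand(2, 1, 4, 6).clamp(1e-4, 1 - 1e-4)  # model output after its sigmoid
target = (torch.rand(2, 1, 4, 6) > 0.9).float()        # binary mask, mostly zeros

neg_weight, pos_weight = 0.5, 74.0  # class weights from the question

# Per-pixel weights with the same shape as the target, so they satisfy
# the "broadcastable to the input" requirement of `weight`.
pixel_weights = target * pos_weight + (1 - target) * neg_weight

# nn.BCELoss fixes `weight` at construction time, so the functional form is
# more convenient when the weights change with every batch.
loss = nn.functional.binary_cross_entropy(probs, target, weight=pixel_weights)
```

Because every pixel's loss is multiplied by 0.5 or 74, this equals half of BCEWithLogitsLoss(pos_weight=148) applied to the corresponding logits; only the overall scale differs.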

I guess I don’t quite understand the difference between weight and pos_weight, but pos_weight is the class weight, is that it?
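The difference can be made concrete by writing the per-element loss out by hand. In this sketch (hypothetical logits and targets), weight multiplies the whole loss of each element whatever its class, while pos_weight multiplies only the term that is active for positive targets, which is why it behaves like a class weight for the positive class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0], [-1.0], [0.5]])  # hypothetical: 3 samples, 1 class
target = torch.tensor([[1.0], [0.0], [1.0]])

# Per-element terms of binary cross-entropy, written out by hand.
pos_term = -target * F.logsigmoid(logits)          # active only where target == 1
neg_term = -(1 - target) * F.logsigmoid(-logits)   # log(1 - sigmoid(x)), target == 0

# `weight` rescales the whole loss of each element (here: per sample).
w = torch.tensor([[1.0], [2.0], [3.0]])
weighted = nn.BCEWithLogitsLoss(weight=w, reduction='none')(logits, target)
manual_weighted = w * (pos_term + neg_term)

# `pos_weight` rescales only the positive term, i.e. it acts as a class weight.
pw = torch.tensor([4.0])
pos_weighted = nn.BCEWithLogitsLoss(pos_weight=pw, reduction='none')(logits, target)
manual_pos_weighted = pw * pos_term + neg_term
```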

So, I should:

  1. Remove the sigmoid layer at the end of my model;
  2. Apply the class weights as below;
  3. Apply a sigmoid to my predictions at inference time.

Is this correct?
class_weights = torch.tensor([74.0 / 0.5], dtype=torch.float32)
criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
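The three steps can be sketched end to end as follows. The one-layer model, the random batch, the SGD optimizer and the 0.5 decision threshold are placeholders chosen for illustration; only the BCEWithLogitsLoss(pos_weight=...) setup comes from the thread.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the segmentation model: no sigmoid at the end.
model = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3, padding=1)

pos_weight = torch.tensor([74.0 / 0.5], dtype=torch.float32)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

images = torch.randn(2, 3, 16, 16)                 # hypothetical batch
masks = (torch.rand(2, 1, 16, 16) > 0.95).float()  # binary targets, same shape as logits

# Training step: the loss receives raw logits.
model.train()
logits = model(images)
loss = criterion(logits, masks)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Inference: apply the sigmoid explicitly, then threshold (0.5 is an assumption).
model.eval()
with torch.no_grad():
    probs = torch.sigmoid(model(images))
    pred_mask = (probs > 0.5).float()
```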

Also, if I can ask, is this the right way to calculate class weights? I have

no_class_count = 480473341       # Number of pixels with no_class
class_count = 3279159            # Number of pixels with class
total_samples = no_class_count + class_count
no_class_weight = total_samples / (2 * no_class_count)    # ≈ 0.50
class_weight = total_samples / (2 * class_count)          # ≈ 73.76
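The arithmetic can be checked directly. This sketch reuses the counts from the question and shows that, for pos_weight, only the ratio of the two weights matters, and that this ratio reduces to no_class_count / class_count.

```python
no_class_count = 480473341  # pixels without the class
class_count = 3279159       # pixels with the class
total_samples = no_class_count + class_count

# "Balanced" inverse-frequency weights: total / (num_classes * count)
no_class_weight = total_samples / (2 * no_class_count)
class_weight = total_samples / (2 * class_count)

# BCEWithLogitsLoss only needs the ratio, which equals num_neg / num_pos.
pos_weight = class_weight / no_class_weight

print(f"{no_class_weight:.4f} {class_weight:.2f} {pos_weight:.2f}")
```

The ratio is about 146.5, so the 74.0 / 0.5 = 148 used above is close to it; the small difference comes from rounding the two weights before dividing.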

Also, for the binary case I shouldn’t use what I had previously, but the following instead. Is this correct?

class_weights = torch.tensor([0.5, 74.0])   # Raises error due to dimensions
class_weights = torch.tensor([74.0 / 0.5])
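A quick way to check the shape rule, using hypothetical [N, 1, H, W] tensors: pos_weight has one entry per class (here a single positive class), and a two-element tensor gets broadcast against the width dimension, which is essentially the same mismatch as in the original error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

logits = torch.randn(2, 1, 4, 6)                 # hypothetical [N, C=1, H, W] logits
target = (torch.rand(2, 1, 4, 6) > 0.9).float()

# One entry per class: a single positive class -> shape [1], broadcasts fine.
ok = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([74.0 / 0.5]))(logits, target)

# Two entries are broadcast against the last (width) dimension and fail.
try:
    nn.BCEWithLogitsLoss(pos_weight=torch.tensor([0.5, 74.0]))(logits, target)
    failed = False
except RuntimeError:
    failed = True
```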

Yes, your approach looks generally correct.

A simple choice for pos_weight would be num_negatives / num_positives, as it weights the positive samples so that both classes have approximately the same influence on the loss.
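To see what "approximately the same influence" means, this sketch uses a hypothetical 20x20 mask with 20 positive pixels. At an uninformative prediction (all logits 0, i.e. p = 0.5) every pixel has the same unweighted loss log(2), so with pos_weight = num_negatives / num_positives the summed loss of the positive pixels equals that of the negative pixels. Once the model starts predicting, the balance holds only approximately.

```python
import torch
import torch.nn as nn

# Hypothetical mask: 20 positive pixels out of 400.
target = torch.zeros(1, 1, 20, 20)
target[..., :2, :10] = 1.0

num_pos = target.sum()
num_neg = target.numel() - num_pos
pos_weight = num_neg / num_pos  # 380 / 20 = 19

# Uninformative prediction: logit 0 everywhere, so each pixel costs log(2) unweighted.
logits = torch.zeros_like(target)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.reshape(1), reduction='none')
per_pixel = criterion(logits, target)

# Summed contribution of each class to the loss.
pos_total = per_pixel[target == 1].sum()
neg_total = per_pixel[target == 0].sum()
```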

I have a question: I am using focal loss to tackle class imbalance, and I do not have a sigmoid layer at the end of my model. So I should be using BCEWithLogits, correct?
The class weights in my code were obtained using the inverse class frequencies.
This is my focal loss implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Targets are 0/1 per pixel; one-hot encoding them would double their
        # length after flattening, so they are only cast to float here.
        targets = targets.float().reshape(-1)
        inputs = torch.sigmoid(inputs).reshape(-1)
        # Per-pixel class weights from inverse class frequencies:
        # 1.11 for class 0 and 10.04 for class 1.
        class_weights = targets * 10.04 + (1 - targets) * 1.11
        # Keep per-pixel losses so the focal term is applied pixel by pixel.
        BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)  # probability assigned to the true class
        focal_loss = self.alpha * ((1 - pt) ** self.gamma) * BCE_loss * class_weights
        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        elif self.reduction == 'sum':
            return torch.sum(focal_loss)
        elif self.reduction == 'none':
            return focal_loss
        raise ValueError("Invalid reduction mode. options: 'mean', 'sum', or 'none'.")
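If the model outputs raw logits, a focal loss can be computed from F.binary_cross_entropy_with_logits instead of applying torch.sigmoid and then F.binary_cross_entropy, which avoids taking the log of saturated probabilities. The class below (BinaryFocalLossWithLogits is a hypothetical name) is a sketch, not the implementation from the question: it keeps per-pixel losses before applying the focal term, and it swaps the constant alpha for the alpha-balanced form from the original focal loss paper, where positives get alpha and negatives 1 - alpha. That is the same form used by torchvision.ops.sigmoid_focal_loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BinaryFocalLossWithLogits(nn.Module):
    """Sketch of an alpha-balanced binary focal loss computed from raw logits."""

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        targets = targets.float()
        # Numerically stable per-pixel BCE, computed directly from logits.
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        pt = torch.exp(-bce)  # model's probability for the true class
        # alpha for positive pixels, (1 - alpha) for negative pixels
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        loss = alpha_t * (1 - pt) ** self.gamma * bce
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'none':
            return loss
        raise ValueError("Invalid reduction mode. options: 'mean', 'sum', or 'none'.")
```

With gamma=0 and alpha=0.5 the focal term disappears and the loss reduces to half of the plain BCE, which is a handy sanity check.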

You can use the focal loss provided by Segmentation Models PyTorch (segmentation_models_pytorch, imported as smp), specifying it inside __init__:

self.loss_fn = smp.losses.FocalLoss(mode=smp.losses.BINARY_MODE)

Then, inside your step function (train, validation, test) you can calculate the loss as follows:

logits = self.forward(image)
loss = self.loss_fn(logits, ground_truth)