Using Pixel Weight Map with BCEWithLogitsLoss


I am trying to implement a version of BCEWithLogitsLoss that applies a different weight to some pixels. The goal is to get accurate borders in a segmentation task.

The outputs of my net are logits.

My weight map has the same size as my target (for example, [batch_size, height, width]).

My implementation currently looks like this:

import torch

def BCE_with_logits_loss_pixel_map(pred, target, pix_map):
    # check that pred, target and pix_map all have the same shape
    assert pred.size() == target.size() == pix_map.size(), "Loss components don't have the same size"
    sig_pred = torch.sigmoid(pred)
    unweighted_loss = target * torch.log(sig_pred) + (1 - target) * torch.log(1 - sig_pred)
    weighted_loss = -(pix_map * unweighted_loss)
    mean_weighted_loss = weighted_loss.mean()
    return mean_weighted_loss

batch_size = 8
out_channels = 1
W = 128
H = 128
logits = torch.FloatTensor(batch_size, H, W).normal_()
# binary targets as floats, so they can multiply the float log-probabilities
target = torch.FloatTensor(batch_size, H, W).random_(0, 2)
weights = torch.FloatTensor(batch_size, H, W).random_(1, 3)

loss_val = BCE_with_logits_loss_pixel_map(logits, target, weights)

I would like to know if there are any mistakes in my code?

Thank you.

Hi John!

I would recommend that you use BCEWithLogitsLoss's weight argument:

    mean_weighted_loss = torch.nn.BCEWithLogitsLoss(weight=pix_map)(pred, target)
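In full, a minimal sketch of that call (the shapes and random tensors here are illustrative assumptions; note that both target and pix_map must be float tensors of the same shape as pred):

```python
import torch

batch_size, H, W = 8, 128, 128
pred = torch.randn(batch_size, H, W)                      # raw logits from the net
target = torch.randint(0, 2, (batch_size, H, W)).float()  # binary targets as floats
pix_map = torch.rand(batch_size, H, W) + 1.0              # per-pixel weights in [1, 2)

# weight is multiplied into the element-wise loss before the mean reduction
criterion = torch.nn.BCEWithLogitsLoss(weight=pix_map)
mean_weighted_loss = criterion(pred, target)
```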

As a further comment: You can certainly write your own version of
BCEWithLogitsLoss, but if you do, you should, for reasons of
numerical stability, use logsigmoid() instead of sigmoid():

    log_sig_pred = torch.nn.functional.logsigmoid(pred)
    log_sig_neg_pred = torch.nn.functional.logsigmoid(-pred)
    unweighted_loss = target * log_sig_pred + (1 - target) * log_sig_neg_pred
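Put together, a stable hand-rolled version might look like this (a sketch, keeping the function name from the original post; the sample tensors are illustrative):

```python
import torch
import torch.nn.functional as F

def BCE_with_logits_loss_pixel_map(pred, target, pix_map):
    # logsigmoid(pred)  == log(sigmoid(pred)) and
    # logsigmoid(-pred) == log(1 - sigmoid(pred)), both computed stably
    assert pred.size() == target.size() == pix_map.size()
    unweighted_loss = target * F.logsigmoid(pred) + (1 - target) * F.logsigmoid(-pred)
    return -(pix_map * unweighted_loss).mean()

pred = torch.randn(4, 16, 16)
target = torch.randint(0, 2, (4, 16, 16)).float()
pix_map = torch.rand(4, 16, 16) + 1.0
loss = BCE_with_logits_loss_pixel_map(pred, target, pix_map)
```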


K. Frank


Thank you very much for your help, K. Frank,
I didn't see the weight parameter of torch.nn.BCEWithLogitsLoss.
It is certainly easier than writing my own function.
Also, thanks for the tip about the numerical stability of sigmoid.
Have a great day!