# Weighted pixelwise NLLLoss2d

Hi,

I am currently working on a segmentation problem, where my target is a segmentation mask with only 2 different classes (0 for background, 1 for object).

Until now I have been using `NLLLoss2d`, which works just fine, but I would like to add an additional pixelwise weighting to emphasize the objects' borders. My idea is to create a weight mask for each individual target, calculated on the fly.

A similar solution was given in the thread "Pixelwise weights for MSELoss", so I tried to implement it for `NLLLoss2d`.

Here is my approach:

```python
import torch
import torch.nn.functional as F

# Set properties
batch_size = 10
out_channels = 2
W = 10
H = 10

# Initialize logits etc. with random values
logits = torch.FloatTensor(batch_size, out_channels, H, W).normal_()
target = torch.LongTensor(batch_size, H, W).random_(0, out_channels)
weights = torch.FloatTensor(batch_size, 1, H, W).random_(1, 3)  # pixel weights in {1, 2}

# Calculate log probabilities over the class dimension
logp = F.log_softmax(logits, dim=1)

# Gather log probabilities with respect to target
logp = logp.gather(1, target.view(batch_size, 1, H, W))

# Multiply with weights
weighted_logp = (logp * weights).view(batch_size, -1)

# Rescale so that loss is in approx. same interval
weighted_loss = weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)

# Average over mini-batch
weighted_loss = weighted_loss.mean()
```

I'm not sure about the rescaling part, where I sum the weighted log probabilities of each sample and divide by the corresponding sum of weights (`weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)`).
Am I right in assuming that this keeps the loss in approximately the same range, even if one target has a lot of borders and therefore higher weights?

Also, should I sum or average the weighted log probs (`weighted_logp.sum(1)` vs. `weighted_logp.mean(1)`)?
I tried to search for the `F.nll_loss()` implementation but couldn’t find it.
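For reference, the per-sample quantity computed in the snippet can be written as follows, where $w_{n,i}$ is the weight and $y_{n,i}$ the target class of pixel $i$ in sample $n$, $p$ the softmax probabilities, and $B$ the batch size (notation introduced here for illustration; the leading minus sign makes $\ell_n$ a loss rather than a log-likelihood):

$$
\ell_n = -\frac{\sum_{i=1}^{HW} w_{n,i}\,\log p_{n,i,y_{n,i}}}{\sum_{i=1}^{HW} w_{n,i}},
\qquad
\mathcal{L} = \frac{1}{B}\sum_{n=1}^{B} \ell_n
$$

Dividing by $\sum_i w_{n,i}$ makes $\ell_n$ a weighted average of the per-pixel NLL values, so with positive weights it stays between the smallest and largest per-pixel NLL no matter how many border pixels a target has, and it reduces to the plain mean NLL when all $w_{n,i} = 1$. Using `.mean(1)` instead would divide by $HW$, so images with many heavily weighted pixels would produce larger losses. This is the same kind of normalization `F.nll_loss` applies to its per-class `weight` argument under `reduction='mean'`.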

Does this approach make sense, or am I missing something?

Greets!


This sort of looks good to me. Sorry for the delayed response.

I noticed, however, that I forgot to multiply by `-1` in the last line.

The last line should be:

```python
weighted_loss = -1.0 * weighted_loss.mean()
```
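A minimal sketch, assuming a recent PyTorch version (so the `Variable` wrapper is not needed): with the sign fixed, setting every weight to 1 should reproduce the built-in `F.nll_loss`, which makes a quick sanity check for the weighting code. The function name `weighted_nll_2d` is hypothetical, not part of any library.

```python
import torch
import torch.nn.functional as F


def weighted_nll_2d(logits, target, weights):
    """Hypothetical helper: per-sample weighted NLL, normalized by the weight sum,
    then averaged over the batch.

    logits:  (B, C, H, W) raw scores
    target:  (B, H, W) class indices (int64)
    weights: (B, 1, H, W) non-negative pixel weights
    """
    B = logits.size(0)
    logp = F.log_softmax(logits, dim=1)          # log-probs over the class dimension
    logp = logp.gather(1, target.unsqueeze(1))   # (B, 1, H, W): log-prob of the target class
    weighted_logp = (logp * weights).view(B, -1)
    per_sample = weighted_logp.sum(1) / weights.view(B, -1).sum(1)
    return -1.0 * per_sample.mean()              # negate: log-likelihood -> loss


torch.manual_seed(0)
logits = torch.randn(10, 2, 10, 10)
target = torch.randint(0, 2, (10, 10, 10))
ones = torch.ones(10, 1, 10, 10)

loss = weighted_nll_2d(logits, target, ones)
ref = F.nll_loss(F.log_softmax(logits, dim=1), target)  # built-in mean over all pixels
print(loss.item(), ref.item())
```

Because every sample has the same number of pixels, the mean of the per-sample means equals the global mean, so the two values should agree up to floating-point rounding. Scaling all weights by a constant also leaves the result unchanged, which is the effect of the rescaling step.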

Hi ptrblck! I am trying to use your solution in my U-Net implementation to compute a weighted loss that emphasizes boundary pixels. However, I do not entirely understand why you need `logp.gather(1, target.view(batch_size, 1, H, W))`. Could you please explain the reasoning behind it?

Cheers!

This operation gets the log probability of the corresponding target class for each pixel of each sample in the batch.
It corresponds to the target-class term `x[class]` in the formula from the docs of `nn.CrossEntropyLoss`.
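A tiny example with made-up values of what `gather(1, ...)` does: for every pixel it picks the entry along the class dimension (dim 1) selected by the target, i.e. the log-probability of the correct class. The shapes are arbitrary and chosen only for illustration.

```python
import torch

# Made-up log-probabilities for 1 sample, 2 classes, a 1x3 "image": shape (B=1, C=2, H=1, W=3)
logp = torch.tensor([[[[-0.1, -2.0, -0.5]],
                      [[-2.3, -0.2, -0.9]]]])
# Target class per pixel: shape (B=1, H=1, W=3)
target = torch.tensor([[[0, 1, 1]]])

# Index dim 1 (classes) with the target; the index needs shape (B, 1, H, W)
picked = logp.gather(1, target.unsqueeze(1))
print(picked)  # tensor([[[[-0.1000, -0.2000, -0.9000]]]])
```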

Thank you for your prompt response! I am also wondering about the `F.log_softmax(logits)` call.

I understand that it is part of the loss formula, but I thought softmax computes the probabilities along one dimension, which in the 2D case would be the channel dimension, right? Or are you computing the softmax over H×W? It does not work for me: `log_softmax` returns a tensor of shape (B, C, H, W) full of zeros. I have a pixelwise two-class classification problem (background vs. foreground), so in my case both the number of output channels and the number of target channels is 1.

Am I missing something? I appreciate your help!

```python
import torch
import torch.nn as nn


def lossWithWeightmap(self, logit_output, target, weight_map):
    # logit_output has shape (B, C, H, W)
    logSoftmaxOutput = nn.functional.log_softmax(logit_output)

    # target has shape (B, 1, H, W)
    target = target.type(torch.long)
    logSoftmaxOutput = logSoftmaxOutput.gather(1, target.view(self.batchSize, 1, self.imgSize, self.imgSize))
    weightedOutput = (logSoftmaxOutput * weight_map).view(self.batchSize, -1)
    weightedLoss = weightedOutput.sum(1) / weight_map.view(self.batchSize, -1).sum(1)
    weightedLoss = -1.0 * weightedLoss.mean()
    return weightedLoss
```

Your output should have the shape `[batch_size, nb_classes=2, height, width]` for a two-class classification with CrossEntropyLoss.
Also, use `F.log_softmax(logit, dim=1)` to apply the log_softmax in the class dimension.
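A quick way to see both points, assuming the model outputs raw logits: `log_softmax` over a class dimension of size 1 is always `log(1) = 0`, which explains the all-zero tensor, while with two channels and `dim=1` the exponentiated outputs form a proper distribution over background and object for each pixel.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# One output channel: log_softmax over a size-1 dimension is log(1) = 0 for every pixel
one_channel = torch.randn(4, 1, 8, 8)
print(torch.count_nonzero(F.log_softmax(one_channel, dim=1)))

# Two output channels (background, object): probabilities over dim=1 sum to 1 per pixel
two_channel = torch.randn(4, 2, 8, 8)
logp = F.log_softmax(two_channel, dim=1)
print(torch.allclose(logp.exp().sum(dim=1), torch.ones(4, 8, 8)))
```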


Totally makes sense - thank you for helping and sharing your thoughts, @ptrblck!


Hi ptrblck, I'm using your loss as shown below, but it raises:

`TypeError: softmax() received an invalid combination of arguments - got (Tensor), but expected one of: * (Tensor input, name dim, torch.dtype dtype) * (Tensor input, int dim, torch.dtype dtype)`

My task is probably similar to rajibmon's, so I changed the U-Net model to:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class UNet(nn.Module):

    def __init__(self, in_channels=1, num_class=2, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.num_classes = num_class
        # self.input = FCN._block(in_channels, features, name="input")
        # self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder1 = UNet._block(in_channels, features * 1, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features * 1, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(features * 16, features * 8, kernel_size=2, stride=2)
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(features * 8, features * 4, kernel_size=2, stride=2)
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(features * 4, features * 2, kernel_size=2, stride=2)
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(features * 2, features, kernel_size=2, stride=2)
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(in_channels=features, out_channels=num_class, kernel_size=1)

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottleneck = self.bottleneck(self.pool4(enc4))
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = F.dropout(self.decoder4(dec4), 0.5, training=True)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = F.dropout(self.decoder3(dec3), 0.5, training=True)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = F.dropout(self.decoder2(dec2), 0.5, training=True)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = F.dropout(self.decoder1(dec1), 0.5, training=True)
        out = self.conv(dec1)
        out = torch.softmax(out)  # no `dim` given: this is the call that raises the TypeError above
        return out

    @staticmethod
    def _block(in_channels, features, name):
        # Hypothetical helper: _block was not shown in the post. Assumed here to be
        # two 3x3 conv + BatchNorm + ReLU layers, as in common U-Net implementations.
        return nn.Sequential(OrderedDict([
            (name + "conv1", nn.Conv2d(in_channels, features, kernel_size=3, padding=1, bias=False)),
            (name + "norm1", nn.BatchNorm2d(features)),
            (name + "relu1", nn.ReLU(inplace=True)),
            (name + "conv2", nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False)),
            (name + "norm2", nn.BatchNorm2d(features)),
            (name + "relu2", nn.ReLU(inplace=True)),
        ]))
```
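The `TypeError` quoted in this post is consistent with `torch.softmax` being called without its required `dim` argument (`out = torch.softmax(out)` at the end of `forward`). A minimal sketch of the difference, using a random tensor as a stand-in for the `(B, num_class, H, W)` output of `self.conv`. Note also that the loss function in this post already applies `F.log_softmax`, so the model would typically return the raw logits from `self.conv` instead of probabilities.

```python
import torch

out = torch.randn(2, 2, 4, 4)  # stand-in for the (B, num_class, H, W) logits

try:
    torch.softmax(out)  # no dim: raises a TypeError like the one in the post
    raised = False
except TypeError:
    raised = True

probs = torch.softmax(out, dim=1)  # softmax over the class dimension
print(raised, probs.shape)
```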

The loss is now defined as:

```python
import torch
import torch.nn.functional as F


def weightedpixelcros(bseg, bgt, bs):
    out = F.log_softmax(bseg)
    weighs = torch.FloatTensor(bs, 1, 256, 256).random_(1, 3)
    out = out.gather(1, bgt.view(bs, 1, 256, 256))
    wout = (out * weighs).view(bs, -1)
    wgloss = wout.sum(1) / weighs.view(bs, -1).sum(1)
    wgloss = -1.0 * wgloss.mean()
    return wgloss


net = UNet(in_channels=1, num_class=2, init_features=64)
net.to(device)
bone_seg = net(input_mr)
```
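A minimal sketch of how this loss function could look with the earlier points from the thread applied: `dim=1` for `F.log_softmax`, explicit `*` operators, and a weight map on the same device as the logits. The signature is changed (hypothetically) to receive the weight map instead of drawing random weights inside the function; computing real boundary weights from `bgt` is not shown. It also assumes the network returns raw logits (no softmax in `forward`), since the loss applies `F.log_softmax` itself.

```python
import torch
import torch.nn.functional as F


def weightedpixelcros(bseg, bgt, weights):
    """Weighted pixelwise cross-entropy (sketch).

    bseg:    (B, C, H, W) raw logits from the network (no softmax applied)
    bgt:     (B, H, W) or (B, 1, H, W) integer class targets
    weights: (B, 1, H, W) pixel weights on the same device as bseg
    """
    bs = bseg.size(0)
    logp = F.log_softmax(bseg, dim=1)                               # log-probs over classes
    logp = logp.gather(1, bgt.long().view(bs, 1, *bseg.shape[2:]))  # target-class log-probs
    wout = (logp * weights).view(bs, -1)
    wgloss = wout.sum(1) / weights.view(bs, -1).sum(1)              # weight-normalized per sample
    return -1.0 * wgloss.mean()                                     # negate and average over batch


# Usage with random stand-ins for the network output and the ground truth
bs, H, W = 2, 256, 256
bseg = torch.randn(bs, 2, H, W)
bgt = torch.randint(0, 2, (bs, H, W))
weights = torch.randint(1, 3, (bs, 1, H, W), device=bseg.device).float()  # placeholder weights
loss = weightedpixelcros(bseg, bgt, weights)
print(loss.item())
```

With all weights set to 1 this reduces to `F.cross_entropy(bseg, bgt)`, which is a convenient check before plugging in real boundary weights.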