Weighted pixelwise NLLLoss2d

Hi,

I am currently working on a segmentation problem, where my target is a segmentation mask with only 2 different classes (0 for background, 1 for object).

Until now I have been using NLLLoss2d, which works just fine, but I would like to add an additional pixelwise weighting at the object’s borders. I thought about creating a weight mask for each individual target, which would be calculated on the fly.

A similar solution was given in this thread: Pixelwise weights for MSELoss, so I tried to implement it for the NLLLoss2d.
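One way to build such a weight mask on the fly (just a sketch of what I have in mind; border_weight and kernel_size are arbitrary choices) would be to take the difference between a dilated and an eroded version of the target mask:

import torch
import torch.nn.functional as F

def border_weights(target, border_weight=2.0, kernel_size=3):
    # target: LongTensor (N, H, W) with values {0, 1}
    mask = target.unsqueeze(1).float()  # (N, 1, H, W)
    pad = kernel_size // 2
    dilated = F.max_pool2d(mask, kernel_size, stride=1, padding=pad)   # grow the object
    eroded = -F.max_pool2d(-mask, kernel_size, stride=1, padding=pad)  # shrink the object
    border = dilated - eroded            # 1 inside the border band, 0 elsewhere
    return 1.0 + (border_weight - 1.0) * border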

Here is my approach:

import torch
import torch.nn.functional as F

# Set properties
batch_size = 10
out_channels = 2
W = 10
H = 10

# Initialize logits etc. with random values
logits = torch.FloatTensor(batch_size, out_channels, H, W).normal_()
target = torch.LongTensor(batch_size, H, W).random_(0, out_channels)
weights = torch.FloatTensor(batch_size, 1, H, W).random_(1, 3)

# Calculate log probabilities in the class dimension
logp = F.log_softmax(logits, dim=1)

# Gather log probabilities with respect to target
logp = logp.gather(1, target.view(batch_size, 1, H, W))

# Multiply with the pixelwise weights
weighted_logp = (logp * weights).view(batch_size, -1)

# Rescale so that the loss is in approx. the same interval
weighted_loss = weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)

# Average over the mini-batch
weighted_loss = weighted_loss.mean()

I’m not sure about the rescaling part, where I sum the weighted log probabilities for each sample and divide by the corresponding sum of weights (weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)).
Am I right in assuming that this should keep the log probs in approx. the same interval, even if one target has a lot of borders and therefore a higher total weight?

Also, should I sum or average the log probs (weighted_logp.sum(1) vs. weighted_logp.mean(1))?
I tried to search for the F.nll_loss() implementation but couldn’t find it.

Does this approach make sense, or am I missing something?

Greets!

This sort of looks good to me. Sorry for the delayed response.

Thanks for the reply!
I noticed, however, that I forgot to multiply by -1 in the last line.

The last line should be:
weighted_loss = -1.0 * weighted_loss.mean()
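
As a sanity check (a sketch): with all weights set to 1, the corrected loss should match F.nll_loss, which averages over all pixels by default:

import torch
import torch.nn.functional as F

batch_size, out_channels, H, W = 10, 2, 10, 10
logits = torch.randn(batch_size, out_channels, H, W)
target = torch.randint(0, out_channels, (batch_size, H, W))
weights = torch.ones(batch_size, 1, H, W)  # uniform weights for the check

logp = F.log_softmax(logits, dim=1)
picked = logp.gather(1, target.view(batch_size, 1, H, W))
weighted_logp = (picked * weights).view(batch_size, -1)
weighted_loss = weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)
weighted_loss = -1.0 * weighted_loss.mean()

print(torch.allclose(weighted_loss, F.nll_loss(logp, target)))  # True

This also answers the sum-vs-mean question: dividing by the weight sum turns the loss into a weighted mean, which reduces to nll_loss’s default mean reduction when the weights are uniform.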

Hi ptrblck! I am trying to use your solution for my UNet implementation to compute a weighted loss that emphasizes boundary pixels. However, I do not entirely understand why you need to use logp.gather(1, target.view(batch_size, 1, H, W)). Can you please explain the reasoning behind it?

Cheers!

This operation gets the log probability of the corresponding target class for each sample in the batch.
It corresponds to the formula in the docs of nn.CrossEntropyLoss.
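
A tiny example (a sketch) of what the gather call picks out:

import torch

logp = torch.tensor([[[[-0.1]], [[-2.3]]]])  # (1, 2, 1, 1): log-probs for classes 0 and 1 at one pixel
target = torch.tensor([[[1]]])               # (1, 1, 1): this pixel belongs to class 1
picked = logp.gather(1, target.view(1, 1, 1, 1))
print(picked)  # tensor([[[[-2.3000]]]]) -> the log-prob of the target class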

Thank you for your prompt response! :) I am also wondering about the log_softmax line:

I understand that it is part of the loss formula, but I thought softmax computes the probabilities along one dimension. In the 2D case this would mean along the channel dimension, right? Or are you computing the softmax along HxW? For me, it does not work: log_softmax returns a tensor full of zeros with shape (B, C, H, W), because I have a pixelwise two-class classification problem (background vs. foreground), so in my case the number of output channels and the number of target channels is 1.

Am I missing something? I appreciate your help! :)

def lossWithWeightmap(self, logit_output, target, weight_map):
    # logit_output has shape (B, C, H, W)
    logSoftmaxOutput = nn.functional.log_softmax(logit_output)

    # target has shape (B, 1, H, W)
    target = target.type(torch.long)
    logSoftmaxOutput = logSoftmaxOutput.gather(1, target.view(self.batchSize, 1, self.imgSize, self.imgSize))
    weightedOutput = (logSoftmaxOutput * weight_map).view(self.batchSize, -1)
    weightedLoss = weightedOutput.sum(1) / weight_map.view(self.batchSize, -1).sum(1)
    weightedLoss = -1.0 * weightedLoss.mean()
    return weightedLoss

Your output should have the shape [batch_size, nb_classes=2, height, width] for a two-class classification with CrossEntropyLoss.
Also, use F.log_softmax(logit, dim=1) to apply the log_softmax in the class dimension.
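
Applied to your function, the fix could look like this (a sketch; it assumes the model now returns two channels and reuses your batchSize/imgSize attributes):

import torch.nn.functional as F

def lossWithWeightmap(self, logit_output, target, weight_map):
    # logit_output: (B, 2, H, W) raw logits; target and weight_map: (B, 1, H, W)
    logp = F.log_softmax(logit_output, dim=1)  # softmax over the class dimension
    logp = logp.gather(1, target.long().view(self.batchSize, 1, self.imgSize, self.imgSize))
    weighted = (logp * weight_map).view(self.batchSize, -1)
    loss = weighted.sum(1) / weight_map.view(self.batchSize, -1).sum(1)
    return -1.0 * loss.mean()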

Totally makes sense - thank you for helping and sharing your thoughts, @ptrblck!

Hi ptrblck, I’m using your loss as below, but it throws:

TypeError: softmax() received an invalid combination of arguments - got (Tensor), but expected one of:
 * (Tensor input, name dim, torch.dtype dtype)
 * (Tensor input, int dim, torch.dtype dtype)

My task is probably the same as rajibmon’s. I changed the UNet model to:

class UNet(nn.Module):
    def __init__(self, in_channels=1, num_class=2, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.num_classes = num_class
        self.encoder1 = UNet._block(in_channels, features * 1, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features * 1, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(features * 16, features * 8, kernel_size=2, stride=2)
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(features * 8, features * 4, kernel_size=2, stride=2)
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(features * 4, features * 2, kernel_size=2, stride=2)
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(features * 2, features, kernel_size=2, stride=2)
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(in_channels=features, out_channels=num_class, kernel_size=1)

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottleneck = self.bottleneck(self.pool4(enc4))
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = F.dropout(self.decoder4(dec4), 0.5, training=True)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = F.dropout(self.decoder3(dec3), 0.5, training=True)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = F.dropout(self.decoder2(dec2), 0.5, training=True)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = F.dropout(self.decoder1(dec1), 0.5, training=True)
        out = self.conv(dec1)
        out = torch.softmax(out)
        return out

The loss is now defined as:

def weightedpixelcros(bseg, bgt, bs):
    out = F.log_softmax(bseg, dim=1)
    weighs = torch.FloatTensor(bs, 1, 256, 256).random_(1, 3)
    out = out.gather(1, bgt.view(bs, 1, 256, 256))
    wout = (out * weighs).view(bs, -1)
    wgloss = wout.sum(1) / weighs.view(bs, -1).sum(1)
    wgloss = -1.0 * wgloss.mean()
    return wgloss

net = UNet(in_channels=1, num_class=2, init_features=64)
net.to(device)
bone_seg = net(input_mr)
optimizer.zero_grad()
loss = weightedpixelcros(bone_seg, bone_gt, bs=BATCH_SIZE)
loss.backward()

Could you please help me to have a look if I’m doing something wrong??

The error was caused by “out = torch.softmax(out)”.

What’s that line of code for? You’ve already got F.log_softmax() in your weightedpixelcros function.
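
For reference, torch.softmax also needs an explicit dim argument, which is what raised the TypeError. A minimal demonstration (a sketch):

import torch

out = torch.randn(2, 2, 4, 4)       # stand-in for the UNet output
# torch.softmax(out)                # TypeError: dim argument is required
probs = torch.softmax(out, dim=1)   # valid, but redundant before F.log_softmax
print(probs.sum(dim=1)[0, 0, 0])    # ~1.0: per-pixel class probabilities sum to 1

Since the loss already applies F.log_softmax, the model should simply return the raw logits from self.conv.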

@ptrblck
Hi, I am trying to do something similar: weighting pixels in the cross-entropy loss for semantic segmentation.

I used your implementation as well as tried

Hi, I am currently using the dice loss for colon segmentation in a UNet, but the network sometimes misses tiny details of the colon. So I thought to add higher weights to the edges of the colon than to other pixels (inside or outside the colon).
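
What I have in mind is something like this edge-weighted BCE term added to the dice loss (just a sketch; edge_weight and kernel_size are placeholders I would have to tune, and it assumes raw logits and a float binary mask):

import torch
import torch.nn.functional as F

def edge_weighted_bce(preds, labels, edge_weight=5.0, kernel_size=5):
    # preds: raw logits (B, 1, H, W); labels: binary float mask (B, 1, H, W)
    pad = kernel_size // 2
    dilated = F.max_pool2d(labels, kernel_size, stride=1, padding=pad)   # grow the colon mask
    eroded = -F.max_pool2d(-labels, kernel_size, stride=1, padding=pad)  # shrink the colon mask
    weights = 1.0 + (edge_weight - 1.0) * (dilated - eroded)  # boundary band gets edge_weight
    return F.binary_cross_entropy_with_logits(preds, labels, weight=weights)

e.g. loss = self.loss_module(preds, labels) + edge_weighted_bce(preds, labels) in training_step.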


This is my code for the model:

class ColonModule(pl.LightningModule):
    def __init__(self, config, segModel=None, pretrainedModel=None, in_channels=1):
        super().__init__()

        self.save_hyperparameters(ignore=["pretrainedModel"])
        self.config = config
        self.pretrainedModel = pretrainedModel
        if self.pretrainedModel is not None:
            self.pretrainedModel.freeze()
            in_channels += 1

        self.model = segModel(
            encoder_name=config["encoder_name"],
            encoder_weights=config["encoder_weights"],
            in_channels=config["in_channels"],
            classes=1,
            activation=None,
        )

        self.loss_module = smp.losses.DiceLoss(mode="binary", smooth=config["loss_smooth"])
        self.val_step_outputs = []
        self.val_step_labels = []

    def forward(self, batch):
        imgs = batch

        if self.pretrainedModel is not None:
            self.pretrainedModel.eval()
            with torch.no_grad():
                initialMask = self.pretrainedModel(imgs)
                initialMask = torch.sigmoid(initialMask)

            imgMask = torch.cat((imgs, initialMask), 1)
            preds = self.model(imgMask)
        else:
            preds = self.model(imgs)
        return preds

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), **self.config["optimizer_params"])

        if self.config["scheduler"]["name"] == "CosineAnnealingLR":
            scheduler = CosineAnnealingLR(
                optimizer,
                **self.config["scheduler"]["params"]["CosineAnnealingLR"],
            )
            lr_scheduler_dict = {"scheduler": scheduler, "interval": "step"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}
        elif self.config["scheduler"]["name"] == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(
                optimizer,
                **self.config["scheduler"]["params"]["ReduceLROnPlateau"],
            )
            lr_scheduler = {"scheduler": scheduler, "monitor": "val_loss"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def training_step(self, batch, batch_idx):
        imgs, labels, _ = batch

        if self.pretrainedModel is not None:
            self.pretrainedModel.eval()
            with torch.no_grad():
                initialMask = self.pretrainedModel(imgs)
                initialMask = torch.sigmoid(initialMask)
            imgMask = torch.cat((imgs, initialMask), 1)
            preds = self.model(imgMask)
        else:
            preds = self.model(imgs)

        if self.config["image_size"] != 512:
            preds = torch.nn.functional.interpolate(preds, size=512, mode="bilinear")
        loss = self.loss_module(preds, labels)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=8)

        for param_group in self.trainer.optimizers[0].param_groups:
            lr = param_group["lr"]
        self.log("lr", lr, on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        imgs, labels, _ = batch
        if self.pretrainedModel is not None:
            initialMask = self.pretrainedModel(imgs)
            initialMask = torch.sigmoid(initialMask)
            imgMask = torch.cat((imgs, initialMask), 1)
            preds = self.model(imgMask)
        else:
            preds = self.model(imgs)

        if self.config["image_size"] != 512:
            preds = torch.nn.functional.interpolate(preds, size=512, mode="bilinear")
        loss = self.loss_module(preds, labels)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.val_step_outputs.append(preds.cpu())
        self.val_step_labels.append(labels.cpu())

    def on_validation_epoch_end(self):
        all_preds = torch.cat(self.val_step_outputs).float()
        all_labels = torch.cat(self.val_step_labels)

        all_preds = torch.sigmoid(all_preds)
        self.val_step_outputs.clear()
        self.val_step_labels.clear()
        val_dice = dice(all_preds, all_labels.long())
        self.log("val_dice", val_dice, on_step=False, on_epoch=True, prog_bar=True)
        if self.trainer.global_rank == 0:
            print(f"\nEpoch: {self.current_epoch}", flush=True)