Medical image segmentation

Hi, I am currently using the Dice loss for colon segmentation with a UNet, but the network sometimes misses tiny details of the colon. So I thought of giving the pixels at the edges of the colon higher weights than the other pixels (inside or outside the colon).
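Roughly, what I have in mind is something like this (just a rough, untested sketch; the kernel size and the edge weight are arbitrary): build a per-pixel weight map from the ground-truth mask, where the boundary band gets a larger weight, and use it in a weighted pixel-wise loss.

import torch
import torch.nn.functional as F

def boundary_weight_map(mask, edge_weight=5.0):
    # mask: (N, 1, H, W) binary ground-truth mask (1 = colon, 0 = background), float
    # The boundary band is where the dilated and eroded masks disagree.
    dilated = F.max_pool2d(mask, kernel_size=3, stride=1, padding=1)
    eroded = -F.max_pool2d(-mask, kernel_size=3, stride=1, padding=1)
    weights = torch.ones_like(mask)
    weights[(dilated - eroded) > 0] = edge_weight
    return weights

# e.g. a per-pixel weighted BCE instead of (or on top of) the Dice loss:
# pixel_loss = F.binary_cross_entropy_with_logits(preds, labels, reduction="none")
# loss = (boundary_weight_map(labels) * pixel_loss).mean()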


This is my code for the model:

import torch
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torchmetrics.functional import dice


class ColonModule(pl.LightningModule):
    def __init__(self, config, segModel=None, pretrainedModel=None, in_channels=1):
        super().__init__()

        self.save_hyperparameters(ignore=["pretrainedModel"])
        self.config = config
        self.pretrainedModel = pretrainedModel
        if self.pretrainedModel is not None:
            # Freeze the pretrained model and feed its mask as an extra input channel.
            self.pretrainedModel.freeze()
            in_channels += 1

        self.model = segModel(
            encoder_name=config["encoder_name"],
            encoder_weights=config["encoder_weights"],
            in_channels=config["in_channels"],
            classes=1,
            activation=None,
        )

        self.loss_module = smp.losses.DiceLoss(mode="binary", smooth=config["loss_smooth"])
        self.val_step_outputs = []
        self.val_step_labels = []

    def forward(self, batch):
        imgs = batch

        if self.pretrainedModel is not None:
            # Run the frozen model to get an initial mask and concatenate it to the input.
            self.pretrainedModel.eval()
            with torch.no_grad():
                initialMask = self.pretrainedModel(imgs)
                initialMask = torch.sigmoid(initialMask)

            imgMask = torch.cat((imgs, initialMask), 1)
            preds = self.model(imgMask)
        else:
            preds = self.model(imgs)
        return preds

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), **self.config["optimizer_params"])

        if self.config["scheduler"]["name"] == "CosineAnnealingLR":
            scheduler = CosineAnnealingLR(
                optimizer,
                **self.config["scheduler"]["params"]["CosineAnnealingLR"],
            )
            lr_scheduler_dict = {"scheduler": scheduler, "interval": "step"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}
        elif self.config["scheduler"]["name"] == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(
                optimizer,
                **self.config["scheduler"]["params"]["ReduceLROnPlateau"],
            )
            lr_scheduler = {"scheduler": scheduler, "monitor": "val_loss"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def training_step(self, batch, batch_idx):
        imgs, labels, _ = batch

        if self.pretrainedModel is not None:
            self.pretrainedModel.eval()
            with torch.no_grad():
                initialMask = self.pretrainedModel(imgs)
                initialMask = torch.sigmoid(initialMask)
            imgMask = torch.cat((imgs, initialMask), 1)
            preds = self.model(imgMask)
        else:
            preds = self.model(imgs)

        if self.config["image_size"] != 512:
            # Upsample predictions to the 512 x 512 label resolution before the loss.
            preds = torch.nn.functional.interpolate(preds, size=512, mode="bilinear")
        loss = self.loss_module(preds, labels)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=8)

        for param_group in self.trainer.optimizers[0].param_groups:
            lr = param_group["lr"]
        self.log("lr", lr, on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        imgs, labels, _ = batch
        if self.pretrainedModel is not None:
            initialMask = self.pretrainedModel(imgs)
            initialMask = torch.sigmoid(initialMask)
            imgMask = torch.cat((imgs, initialMask), 1)
            preds = self.model(imgMask)
        else:
            preds = self.model(imgs)

        if self.config["image_size"] != 512:
            preds = torch.nn.functional.interpolate(preds, size=512, mode="bilinear")
        loss = self.loss_module(preds, labels)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.val_step_outputs.append(preds.cpu())
        self.val_step_labels.append(labels.cpu())

    def on_validation_epoch_end(self):
        # Aggregate predictions and labels over the whole validation epoch.
        all_preds = torch.cat(self.val_step_outputs).float()
        all_labels = torch.cat(self.val_step_labels)

        all_preds = torch.sigmoid(all_preds)
        self.val_step_outputs.clear()
        self.val_step_labels.clear()
        val_dice = dice(all_preds, all_labels.long())
        self.log("val_dice", val_dice, on_step=False, on_epoch=True, prog_bar=True)
        if self.trainer.global_rank == 0:
            print(f"\nEpoch: {self.current_epoch}", flush=True)

Hi Samir!

If I understand you correctly, you know in your ground-truth data where the
boundary of the colon is. I would suggest then that you perform three-class
semantic segmentation – not colon (i.e., background), interior of colon, and
boundary of colon.

Rather than add some weights to “encourage” the network to better learn
the edges of the colon, just train it outright to learn the boundary of the colon
as a third class.
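
One way to build such a three-class target from the binary colon mask you already have might look like this (just a sketch, assuming the mask is an (N, 1, H, W) float tensor of 0s and 1s; the boundary thickness is set by the pooling kernel size):

import torch
import torch.nn.functional as F

def three_class_target(mask):
    # mask: (N, 1, H, W) binary colon mask (1 = colon, 0 = background)
    # "Erode" the mask by min-pooling (max-pooling the inverted mask).
    interior = -F.max_pool2d(-mask, kernel_size=3, stride=1, padding=1)
    boundary = mask - interior          # colon pixels that are not interior
    target = torch.zeros_like(mask)     # 0 = background
    target[interior > 0] = 1            # 1 = interior of colon
    target[boundary > 0] = 2            # 2 = boundary of colon
    return target.squeeze(1).long()     # (N, H, W) class indices for CrossEntropyLoss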

(It might work for your use case or it might not, but I think it’s worth a try.)

I know that a lot of people advocate using dice loss, especially in the case of
unbalanced classes, and dice loss is popular in the medical-imaging community.
Nonetheless, my intuition remains (not based on any testing) that you should
generally prefer CrossEntropyLoss (or BCEWithLogitsLoss) for semantic
segmentation. My reasoning is that cross entropy has a logarithmic divergence
when your model’s prediction is highly wrong and, in my experience, this
divergence is very helpful for training (for reasons I don’t really understand).
For class imbalance, class weights generally work well.
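
For example, something along these lines (a sketch; the weight values are placeholders that you would tune to your class frequencies, e.g. from pixel counts):

import torch
import torch.nn as nn

# Classes: 0 = background, 1 = interior of colon, 2 = boundary of colon.
# The weights below are only illustrative; the rare boundary class gets a larger weight.
class_weights = torch.tensor([0.2, 1.0, 5.0])
criterion = nn.CrossEntropyLoss(weight=class_weights)

# preds: (N, 3, H, W) raw logits (classes=3, activation=None in the model)
# labels: (N, H, W) integer class indices
# loss = criterion(preds, labels)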

I would recommend at least starting with CrossEntropyLoss (with class weights,
as appropriate) and if it’s not working as well as you would like (maybe because
of class imbalance), augment it with something like dice loss.
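
If you do end up combining them, one common pattern is a simple weighted sum of the two losses (again, just a sketch; the 0.5 mixing factor is arbitrary):

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

class_weights = torch.tensor([0.2, 1.0, 5.0])          # illustrative values, as above
ce_loss = nn.CrossEntropyLoss(weight=class_weights)
dice_loss = smp.losses.DiceLoss(mode="multiclass")     # takes logits and integer labels

def combined_loss(preds, labels, dice_factor=0.5):
    # preds: (N, 3, H, W) logits; labels: (N, H, W) class indices
    return ce_loss(preds, labels) + dice_factor * dice_loss(preds, labels)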

Best.

K. Frank