Hi, I am currently using the Dice loss for colon segmentation with a UNet, but the network sometimes misses tiny details of the colon. So I thought of giving the pixels on the colon edges a higher weight than the other pixels (inside or outside the colon).
This is my code for the model:
import torch
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torchmetrics.functional import dice


class ColonModule(pl.LightningModule):
    def __init__(self, config, segModel=None, pretrainedModel=None, in_channels=1):
        super().__init__()
        self.save_hyperparameters(ignore=["pretrainedModel"])
        self.config = config
        self.pretrainedModel = pretrainedModel
        if self.pretrainedModel is not None:
            self.pretrainedModel.freeze()
            # extra input channel for the mask predicted by the frozen model
            in_channels += 1
        self.model = segModel(
            encoder_name=config["encoder_name"],
            encoder_weights=config["encoder_weights"],
            in_channels=config["in_channels"],
            classes=1,
            activation=None,
        )
        self.loss_module = smp.losses.DiceLoss(mode="binary", smooth=config["loss_smooth"])
        self.val_step_outputs = []
        self.val_step_labels = []
    def forward(self, batch):
        imgs = batch
        if self.pretrainedModel is not None:
            self.pretrainedModel.eval()
            with torch.no_grad():
                initialMask = self.pretrainedModel(imgs)
                initialMask = torch.sigmoid(initialMask)
            # concatenate the predicted mask as an extra input channel
            imgMask = torch.cat((imgs, initialMask), 1)
            preds = self.model(imgMask)
        else:
            preds = self.model(imgs)
        return preds
    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), **self.config["optimizer_params"])
        if self.config["scheduler"]["name"] == "CosineAnnealingLR":
            scheduler = CosineAnnealingLR(
                optimizer,
                **self.config["scheduler"]["params"]["CosineAnnealingLR"],
            )
            lr_scheduler_dict = {"scheduler": scheduler, "interval": "step"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}
        elif self.config["scheduler"]["name"] == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(
                optimizer,
                **self.config["scheduler"]["params"]["ReduceLROnPlateau"],
            )
            lr_scheduler_dict = {"scheduler": scheduler, "monitor": "val_loss"}
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}
        # fall back to a bare optimizer if no known scheduler is configured
        return optimizer
    def training_step(self, batch, batch_idx):
        imgs, labels, _ = batch
        if self.pretrainedModel is not None:
            self.pretrainedModel.eval()
            with torch.no_grad():
                initialMask = self.pretrainedModel(imgs)
                initialMask = torch.sigmoid(initialMask)
            imgMask = torch.cat((imgs, initialMask), 1)
            preds = self.model(imgMask)
        else:
            preds = self.model(imgs)
        # labels are 512x512, so upsample the predictions to match if needed
        if self.config["image_size"] != 512:
            preds = torch.nn.functional.interpolate(preds, size=512, mode="bilinear")
        loss = self.loss_module(preds, labels)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=imgs.size(0))
        # log the current learning rate of the first parameter group
        lr = self.trainer.optimizers[0].param_groups[0]["lr"]
        self.log("lr", lr, on_step=True, on_epoch=False, prog_bar=True)
        return loss
    def validation_step(self, batch, batch_idx):
        imgs, labels, _ = batch
        if self.pretrainedModel is not None:
            # validation already runs in eval mode under no_grad, so no
            # explicit torch.no_grad() is needed here
            initialMask = self.pretrainedModel(imgs)
            initialMask = torch.sigmoid(initialMask)
            imgMask = torch.cat((imgs, initialMask), 1)
            preds = self.model(imgMask)
        else:
            preds = self.model(imgs)
        if self.config["image_size"] != 512:
            preds = torch.nn.functional.interpolate(preds, size=512, mode="bilinear")
        loss = self.loss_module(preds, labels)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.val_step_outputs.append(preds.cpu())
        self.val_step_labels.append(labels.cpu())
    def on_validation_epoch_end(self):
        all_preds = torch.cat(self.val_step_outputs).float()
        all_labels = torch.cat(self.val_step_labels)
        all_preds = torch.sigmoid(all_preds)
        self.val_step_outputs.clear()
        self.val_step_labels.clear()
        val_dice = dice(all_preds, all_labels.long())
        self.log("val_dice", val_dice, on_step=False, on_epoch=True, prog_bar=True)
        if self.trainer.global_rank == 0:
            print(f"\nEpoch: {self.current_epoch}", flush=True)
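What I have in mind is roughly the sketch below (untested; the class name EdgeWeightedLoss, the edge_weight and kernel_size parameters, and the max-pool morphological gradient are my own choices, not from any library). It extracts an edge map from the ground-truth mask and adds a per-pixel-weighted BCE term, with a higher weight on boundary pixels, on top of the Dice loss I already use. It assumes the labels arrive as (B, 1, H, W) float masks:

import torch
import torch.nn.functional as F
import segmentation_models_pytorch as smp


class EdgeWeightedLoss(torch.nn.Module):
    """Dice + per-pixel-weighted BCE; boundary pixels get edge_weight."""

    def __init__(self, edge_weight=5.0, kernel_size=3, smooth=0.0):
        super().__init__()
        self.edge_weight = edge_weight
        self.kernel_size = kernel_size
        self.pad = kernel_size // 2
        self.dice = smp.losses.DiceLoss(mode="binary", smooth=smooth)

    def edge_map(self, labels):
        # morphological gradient of the binary mask: dilation minus erosion,
        # both computed with max_pool2d; expects labels of shape (B, 1, H, W)
        dilated = F.max_pool2d(labels, self.kernel_size, stride=1, padding=self.pad)
        eroded = 1.0 - F.max_pool2d(1.0 - labels, self.kernel_size, stride=1, padding=self.pad)
        return dilated - eroded  # 1 on edge pixels, 0 elsewhere

    def forward(self, preds, labels):
        labels = labels.float()
        edges = self.edge_map(labels)
        # weight map: edge_weight on the boundary, 1.0 everywhere else
        weights = 1.0 + (self.edge_weight - 1.0) * edges
        bce = F.binary_cross_entropy_with_logits(preds, labels, weight=weights)
        return self.dice(preds, labels) + bce

The plan would be to swap self.loss_module = smp.losses.DiceLoss(...) in __init__ for self.loss_module = EdgeWeightedLoss(edge_weight=5.0, smooth=config["loss_smooth"]). Is this a sensible way to weight the colon edges, or is there a better-established approach?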