U-Net segmentation (predicting a single class) is stuck at a constant loss

Hi all, I am new to U-Nets; I have read tutorials and implementations online and tried to make my own.
Currently, I’m trying to segment a biomedical image dataset: each sample is a medical image paired with a binary (0, 255) ground-truth mask of the same size (I preprocessed the masks to be binary). I have tried several options, but my network is not learning and the loss stays constant throughout training. Here is my code (I removed the file paths for the images and masks, and I also tried to overfit on a single image, but the loss is still constant):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split

import os
from PIL import Image
import numpy as np
from glob import glob
import matplotlib.pyplot as plt

#use CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#image and mask relative paths
image_dpath=''
mask_dpath=''

# Unet model
def double_conv(in_c, out_c):
    # two unpadded 3x3 convs; each conv shrinks H and W by 2 pixels
    conv = nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3),
        nn.ReLU(inplace=True)
    )
    return conv

def crop_img(tensor, target_tensor):
    # center-crop `tensor` spatially to the size of `target_tensor`
    # (assumes square feature maps and an even size difference)
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = (tensor_size - target_size) // 2
    return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]

class Unet(nn.Module):
    def __init__(self, n_class):
        super(Unet,self).__init__()
        
        self.max_pool_2x2=nn.MaxPool2d(kernel_size=2,stride=2)
        self.down_conv_1=double_conv(1,64)
        self.down_conv_2=double_conv(64,128)
        self.down_conv_3=double_conv(128,256)
        self.down_conv_4=double_conv(256,512)
        self.down_conv_5=double_conv(512,1024)
        
        self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024,out_channels=512,kernel_size=2,stride=2)
        self.up_conv_1 =double_conv(1024,512)
        self.up_trans_2 = nn.ConvTranspose2d(in_channels=512,out_channels=256,kernel_size=2,stride=2)
        self.up_conv_2=double_conv(512,256)
        self.up_trans_3=nn.ConvTranspose2d(in_channels=256,out_channels=128,kernel_size=2,stride=2)
        self.up_conv_3=double_conv(256,128)
        self.up_trans_4=nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=2,stride=2)
        self.up_conv_4=double_conv(128,64)
        
        self.out=nn.Conv2d(in_channels=64,out_channels=n_class,kernel_size=1)
        
    def forward(self,img):
        #encoder
        x1=self.down_conv_1(img)
        x2=self.max_pool_2x2(x1)
        
        x3=self.down_conv_2(x2)
        x4=self.max_pool_2x2(x3)
        
        x5=self.down_conv_3(x4)
        x6=self.max_pool_2x2(x5)
        
        x7=self.down_conv_4(x6)
        x8=self.max_pool_2x2(x7)
        
        x9=self.down_conv_5(x8)
        
        #decoder
        x=self.up_trans_1(x9)
        y=crop_img(x7,x)
        x=self.up_conv_1(torch.cat([x,y],1))
        
        x=self.up_trans_2(x)
        y=crop_img(x5,x)
        x=self.up_conv_2(torch.cat([x,y],1))
        
        x=self.up_trans_3(x)
        y=crop_img(x3,x)
        x=self.up_conv_3(torch.cat([x,y],1))
        
        x=self.up_trans_4(x)
        y=crop_img(x1,x)
        x=self.up_conv_4(torch.cat([x,y],1))
        
        x=self.out(x)
        
        return x
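
# quick sanity check (a sketch): with the unpadded 3x3 convolutions above,
# a 572x572 input should shrink to a 388x388 output, as in the original
# U-Net paper
with torch.no_grad():
    print(Unet(n_class=1)(torch.randn(1, 1, 572, 572)).shape)  # torch.Size([1, 1, 388, 388])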

# Dice loss function
def dice_coeff(pred,target):
    eps = 0.0001 #prevent division by zero
    inter = torch.dot(pred.reshape(-1), target.reshape(-1))
    union = torch.sum(pred) + torch.sum(target) + eps
    t = (2.0 * inter.float() + eps) / union.float()
    return t

def calc_loss(pred, target, bce_weight=0.5):

    bce = F.binary_cross_entropy_with_logits(pred, target)
    pred = torch.sigmoid(pred)
    dice = dice_coeff(pred, target)
    loss = bce * bce_weight + dice * (1 - bce_weight)

    return loss
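
# note: calc_loss above adds the dice *coefficient* to the BCE term, so a
# better overlap actually increases that part of the loss; the usual
# soft-dice loss is (1 - dice_coeff) instead. A sketch of that variant:
def dice_loss(pred, target):
    # pred is expected to already be sigmoid probabilities
    return 1.0 - dice_coeff(pred, target)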

#Dataset class
class img_label(Dataset):
    def __init__(self, img_dpath, mask_dpath):
        self.img_dpath = img_dpath
        self.mask_dpath = mask_dpath

        self.img_ids = [file for file in glob(img_dpath + '\\*.png')]
        self.mask_ids = [file for file in glob(mask_dpath + '\\*.png')]

    @classmethod
    def preprocess(cls,PIL_img):
        w,h = PIL_img.size #w = 512, h = 496
        add_h = int((572 - h)/2)
        add_w = int((572 - w)/2)
        x = np.asarray(PIL_img)  
        x = np.pad(x,((add_h,add_h),(add_w,add_w)),mode='constant')
        pic_mean = np.mean(x)
        pic_std = np.std(x)
        x = (x-pic_mean)/pic_std
        return x
    
    def __getitem__(self,idx):
        imgpath = self.img_ids[idx]
        maskpath = self.mask_ids[idx]
        # note: images must be opened with PIL.Image.open(path) so that preprocess() can operate on them
        img = Image.open(imgpath)
        img = self.preprocess(img)

        mask = Image.open(maskpath)
        mask = self.preprocess(mask)

        img_t = torch.from_numpy(img).type(torch.FloatTensor).unsqueeze(0).cuda()
        mask_t = torch.from_numpy(mask).type(torch.FloatTensor).unsqueeze(0).cuda()
        return (img_t, mask_t)

    def __len__(self):
        return len(self.img_ids)
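
# note: preprocess() mean/std-normalizes the mask exactly like the image, so
# the target values fed to the loss are no longer 0/1. A sketch of a
# mask-specific variant (a hypothetical helper, not wired in above) that
# only pads and binarizes:
def preprocess_mask(PIL_img):
    w, h = PIL_img.size
    add_h = int((572 - h) / 2)
    add_w = int((572 - w) / 2)
    x = np.asarray(PIL_img)
    x = np.pad(x, ((add_h, add_h), (add_w, add_w)), mode='constant')
    return (x > 127).astype(np.float32)  # map 0/255 to 0.0/1.0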

#Create dataset
dataset = img_label(image_dpath,mask_dpath)

#Create train and validation dataloaders
train,val = random_split(dataset,[1,1])  # random_split requires the lengths to sum to len(dataset), so this assumes exactly two images
batch_size=1
train_loader = DataLoader(train, batch_size=batch_size,shuffle=True,num_workers=0)
val_loader = DataLoader(val,batch_size=batch_size,shuffle=True,num_workers=0)

# hyperparameters
n_class = 1 #we are predicting 1 class
learning_rate = 1e-5
optimizer = optim.Adam(Unet(n_class=n_class).parameters(),lr=learning_rate)
model = Unet(n_class=n_class).to(device=device)

# training loop
def training_loop(n_epochs, optimizer, model, loss_fn, train_loader,val_loader):
    for epoch in range(1,n_epochs+1):
        loss_train = 0.0
        a=0
        for imgs,masks in train_loader:

            output = model(imgs)

            # crop the padded 572x572 mask down to the network's 388x388 output size
            # (the random tensor is used only to pass the target size to crop_img)
            target_tensor = torch.randn((batch_size,1,388,388))
            target = crop_img(masks,target_tensor)
            
            optimizer.zero_grad()
            loss = calc_loss(output,target)

            loss.backward()
            optimizer.step()

            loss_train += loss.item()

            # save the current prediction for visual inspection
            trans = transforms.ToPILImage()

            fig = plt.figure()
            plt.imshow(trans(output.detach().cpu().squeeze()))

            fig.savefig(f'{a}.png')
            plt.close(fig)  # avoid accumulating open figures
            a += 1

        if epoch <= 100:
            print(f"Epoch: {epoch}, Training loss: {loss_train/len(train_loader)}")

training_loop(n_epochs=100,optimizer=optimizer,model=model,loss_fn=dice_coeff,train_loader=train_loader,val_loader=val_loader)

Do you get constant predictions?
I’m asking because one thing to keep in mind is that if you have very few “interesting” pixels among a large number of uninteresting ones, you essentially have a heavily imbalanced problem, and that can hinder training. To mitigate this, one can balance the pixels; for example, when training the U-Net in our book, we take care that enough slices containing nodules are fed into the U-Net. There is a brief discussion in section 13.5.5 (Designing our training and validation data).
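
For instance (just a sketch, not the approach from the book), one could also upweight the rare foreground pixels directly in the loss via the pos_weight argument of nn.BCEWithLogitsLoss:

import torch
import torch.nn as nn

# toy example: if only ~5% of pixels are foreground, upweight positives ~19x
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([19.0]))
logits = torch.randn(1, 1, 388, 388)                  # raw network output (no sigmoid)
target = (torch.rand(1, 1, 388, 388) < 0.05).float()  # ~5% foreground, values in {0, 1}
print(criterion(logits, target))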

Best regards

Thomas

I currently get constant loss.
I’ve also changed the loss function to nn.CrossEntropyLoss() (https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) and accordingly compared the two output channels of the last layer of the U-Net against the target, which is a single-channel binary (0, 1) ground-truth image. However, the loss is still stuck at the same value (shape sketch below).
The ground-truth image has a significant number of 1s (white), maybe about 1/20 of the image as an estimate, so I think I can rule out class imbalance.
Cropping might be useful, but I am not sure whether there are other methods that would work as well.
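
For reference, nn.CrossEntropyLoss here expects raw logits of shape (N, 2, H, W) and a long-typed target of shape (N, H, W) with values in {0, 1}; a minimal sketch of the shapes I am using:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(1, 2, 388, 388)         # two output channels from the last layer
target = torch.randint(0, 2, (1, 388, 388))  # binary mask as class indices (int64)
print(criterion(logits, target))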

I realized that I had used a wrong loss function (my manually implemented dice loss), which could have been non-differentiable, so autograd might not have been able to backprop the loss (just a theory). I used nn.BCELoss() instead and it works now.
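
For anyone who hits the same issue, here is a minimal sketch of the working loss setup (the tensors below are stand-ins for the real network output and mask):

import torch
import torch.nn as nn

criterion = nn.BCELoss()
logits = torch.randn(1, 1, 388, 388, requires_grad=True)  # stand-in for the raw network output
prob = torch.sigmoid(logits)                              # nn.BCELoss expects probabilities in [0, 1]
target = (torch.rand(1, 1, 388, 388) < 0.05).float()      # binary {0, 1} mask
loss = criterion(prob, target)
loss.backward()  # gradients flow end to end
print(loss.item())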