Hi all, I am new to Unets, have read tutorials and implementations online and tried to make my own.
Currently, I’m trying to predict a biomedical image dataset with a binary (0,255) ground truth mask (I preprocessed it to be as such) and a medical image, both of the same size. However I have tried several options but my network is not learning and is stuck at a constant loss throughout. Here is my code (i removed the file paths for the image and mask, and tried to overfit on a single image but loss is still constant):
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
import os
from PIL import Image
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
#use CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#image and mask relative paths
image_dpath=''
mask_dpath=''
# Unet model
def double_conv(in_c,out_c):
conv=nn.Sequential(
nn.Conv2d(in_c,out_c,kernel_size=3),
nn.ReLU(inplace=True),
nn.Conv2d(out_c,out_c,kernel_size=3),
nn.ReLU(inplace=True)
)
return conv
def crop_img(tensor,target_tensor):
target_size=target_tensor.size()[2]
tensor_size=tensor.size()[2]
delta=tensor_size-target_size
delta=delta//2
return tensor[:,:,delta:tensor_size-delta,delta:tensor_size-delta]
class Unet(nn.Module):
def __init__(self, n_class):
super(Unet,self).__init__()
self.max_pool_2x2=nn.MaxPool2d(kernel_size=2,stride=2)
self.down_conv_1=double_conv(1,64)
self.down_conv_2=double_conv(64,128)
self.down_conv_3=double_conv(128,256)
self.down_conv_4=double_conv(256,512)
self.down_conv_5=double_conv(512,1024)
self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024,out_channels=512,kernel_size=2,stride=2)
self.up_conv_1 =double_conv(1024,512)
self.up_trans_2 = nn.ConvTranspose2d(in_channels=512,out_channels=256,kernel_size=2,stride=2)
self.up_conv_2=double_conv(512,256)
self.up_trans_3=nn.ConvTranspose2d(in_channels=256,out_channels=128,kernel_size=2,stride=2)
self.up_conv_3=double_conv(256,128)
self.up_trans_4=nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=2,stride=2)
self.up_conv_4=double_conv(128,64)
self.out=nn.Conv2d(in_channels=64,out_channels=n_class,kernel_size=1)
def forward(self,img):
#encoder
x1=self.down_conv_1(img)
x2=self.max_pool_2x2(x1)
x3=self.down_conv_2(x2)
x4=self.max_pool_2x2(x3)
x5=self.down_conv_3(x4)
x6=self.max_pool_2x2(x5)
x7=self.down_conv_4(x6)
x8=self.max_pool_2x2(x7)
x9=self.down_conv_5(x8)
#decoder
x=self.up_trans_1(x9)
y=crop_img(x7,x)
x=self.up_conv_1(torch.cat([x,y],1))
x=self.up_trans_2(x)
y=crop_img(x5,x)
x=self.up_conv_2(torch.cat([x,y],1))
x=self.up_trans_3(x)
y=crop_img(x3,x)
x=self.up_conv_3(torch.cat([x,y],1))
x=self.up_trans_4(x)
y=crop_img(x1,x)
x=self.up_conv_4(torch.cat([x,y],1))
x=self.out(x)
return x
# Dice loss function
def dice_coeff(pred,target):
eps = 0.0001 #prevent division by zero
inter = torch.dot(pred.reshape(-1), target.reshape(-1))
union = torch.sum(pred) + torch.sum(target) + eps
t = (2.0 * inter.float() + eps) / union.float()
return t
def calc_loss(pred, target, bce_weight=0.5):
bce = F.binary_cross_entropy_with_logits(pred, target)
pred = torch.sigmoid(pred)
dice = dice_coeff(pred, target)
loss = bce * bce_weight + dice * (1 - bce_weight)
return loss
#Dataset class
class img_label(Dataset):
def __init__(self, img_dpath, mask_dpath):
self.img_dpath = img_dpath
self.mask_dpath = mask_dpath
self.img_ids = [file for file in glob(img_dpath + '\\*.png')]
self.mask_ids = [file for file in glob(mask_dpath + '\\*.png')]
@classmethod
def preprocess(cls,PIL_img):
w,h = PIL_img.size #w = 512, h = 496
add_h = int((572 - h)/2)
add_w = int((572 - w)/2)
x = np.asarray(PIL_img)
x = np.pad(x,((add_h,add_h),(add_w,add_w)),mode='constant')
pic_mean = np.mean(x)
pic_std = np.std(x)
x = (x-pic_mean)/pic_std
return x
def __getitem__(self,idx):
imgpath = self.img_ids[idx]
maskpath = self.mask_ids[idx]
#note: the images need to be opened by PIL.Image.open(path) in order for it to be preprocessed.
img = Image.open(imgpath)
img = self.preprocess(img)
mask = Image.open(maskpath)
mask = self.preprocess(mask)
return (torch.from_numpy(img).type(torch.FloatTensor).unsqueeze(0).cuda(), torch.from_numpy(mask).type(torch.FloatTensor).unsqueeze(0).cuda())
def __len__(self):
return len(self.img_ids)
#Create dataset
dataset = img_label(image_dpath,mask_dpath)
#Create train and validation dataloaders
train,val = random_split(dataset,[1,1])
batch_size=1
train_loader = DataLoader(train, batch_size=batch_size,shuffle=True,num_workers=0)
val_loader = DataLoader(val,batch_size=batch_size,shuffle=True,num_workers=0)
# hyperparameters
n_class = 1 #we are predicting 1 class
learning_rate = 1e-5
optimizer = optim.Adam(Unet(n_class=n_class).parameters(),lr=learning_rate)
model = Unet(n_class=n_class).to(device=device)
# training loop
def training_loop(n_epochs, optimizer, model, loss_fn, train_loader,val_loader):
for epoch in range(1,n_epochs+1):
loss_train = 0.0
a=0
for imgs,masks in train_loader:
output = model(imgs)
target_tensor = torch.randn((batch_size,1,388,388))
target = crop_img(masks,target_tensor)
optimizer.zero_grad()
loss = calc_loss(output,target)
loss.backward()
optimizer.step()
loss_train += loss.item()
trans = transforms.ToPILImage()
fig = plt.figure()
plt.imshow(trans(output.cpu().squeeze()))
fig.savefig(f'{a}.png')
a+=1
if epoch <= 100:
print(f"Epoch: {epoch}, Training loss: {loss_train/len(train_loader)}")
training_loop(n_epochs=100,optimizer=optimizer,model=model,loss_fn=dice_coeff,train_loader=train_loader,val_loader=val_loader)