U-Net loss function

Hi, I am having some trouble training a U-Net. I have implemented a U-Net in Keras before and am trying to do the same in PyTorch. The problem is that my U-Net in PyTorch doesn't seem to be learning: the train loss stays well under 0.0005, which is terrible. For my Keras U-Net, the train loss improves drastically (compared to PyTorch) from the second epoch onward. I am assuming that something is wrong with either my loss function or my metric function. Can anyone help me figure out what is wrong here? Thanks in advance.

#Loss function

import torch
import torch.nn as nn

class DiceCoefLoss(nn.Module):
    """Negated Dice coefficient: returns values in [-1, 0], where -1 is perfect overlap."""
    def __init__(self):
        super(DiceCoefLoss, self).__init__()

    def forward(self, input, target):
        input, target = input.cuda(), target.cuda()
        smooth = 1

        iflat = input.view(-1)
        tflat = target.view(-1)
        intersection = (iflat * tflat).sum()
        dice_coef = (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)
        print("DICE_COEF loss", -dice_coef)
        return -dice_coef

#Metric function

def dice_coef(output, target):
    smooth = 1
    output, target = output.cpu(), target.cpu()
    iflat = torch.flatten(output)
    tflat = torch.flatten(target)
    intersection = (iflat * tflat).sum()

    return (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)


def dice_coef_loss(output, target):
    return -dice_coef(output, target)
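
For reference, a quick sanity check of the metric on made-up tensors (the values below are invented, not real data) shows the range these functions produce:

# Hypothetical sanity check -- dummy tensors, not real data.
pred = torch.rand(1, 1, 4, 4)      # stand-in for sigmoid outputs in [0, 1]
mask = torch.zeros(1, 1, 4, 4)     # ground-truth mask...
mask[0, 0, 1:3, 1:3] = 1.          # ...with a small square of foreground

print(dice_coef(pred, mask).item())       # low Dice for a poor prediction
print(dice_coef_loss(pred, mask).item())  # negated: always in [-1, 0]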

#Dataset 
import numpy as np
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF


class MyLidcDataset(Dataset):
    def __init__(self, IMAGES_PATHS, MASK_PATHS):
        """
        IMAGES_PATHS: list of image paths, e.g. ['./Images/0001_01_images.npy', './Images/0001_02_images.npy']
        MASK_PATHS: list of mask paths, e.g. ['./Masks/0001_01_masks.npy', './Masks/0001_02_masks.npy']
        """
        self.image_paths = IMAGES_PATHS
        self.mask_paths = MASK_PATHS

    def transform(self, image, mask):
        # Convert numpy arrays to tensors
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        image, mask = image.float(), mask.float()
        return image, mask

    def __getitem__(self, index):
        image = np.load(self.image_paths[index])
        mask = np.load(self.mask_paths[index])
        image, mask = self.transform(image, mask)
        return image, mask

    def __len__(self):
        return len(self.image_paths)
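
A minimal usage sketch, assuming .npy files exist at the paths shown (the paths and batch size here are placeholders):

# Hypothetical usage -- paths and batch size are placeholders.
from torch.utils.data import DataLoader

image_paths = ['./Images/0001_01_images.npy', './Images/0001_02_images.npy']
mask_paths = ['./Masks/0001_01_masks.npy', './Masks/0001_02_masks.npy']

dataset = MyLidcDataset(image_paths, mask_paths)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

images, masks = next(iter(loader))
print(images.shape, images.dtype)  # e.g. torch.Size([2, 1, 512, 512]) torch.float32
print(masks.unique())              # should be {0., 1.} for a binary mask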

What values do your mask tensors contain and how does the loss curve look for your current setup?

Hello, my tensor datatypes are as follows:
input: float
output = model(input): float
target: boolean

I changed my loss function as below, which seems right. Right now I feel like it's my U-Net model that's causing all the trouble.

class DiceCoefLoss(nn.Module):
    def __init__(self):
        super(DiceCoefLoss, self).__init__()

    def forward(self, input, target):
        input, target = input.cuda(), target.cuda()
        smooth = 1

        iflat = input.view(-1)
        tflat = target.view(-1)
        intersection = (iflat * tflat).sum()
        dice_coef = (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)
        return 1 - dice_coef
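
A quick check on dummy data (the tensors below are invented, and a CUDA device is assumed since forward() moves its inputs to the GPU):

# Hypothetical check -- dummy tensors, boolean mask cast to float.
loss_fn = DiceCoefLoss()
pred = torch.rand(2, 1, 8, 8)                  # stand-in for sigmoid outputs
mask = (torch.rand(2, 1, 8, 8) > 0.5).float()  # boolean mask cast to float

print(loss_fn(pred, mask).item())  # bounded in [0, 1]; 0 means perfect overlap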

My U-Net model structure is as follows:


import torch
import torch.nn as nn


class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels, act_func=nn.ReLU(inplace=True)):
        super(VGGBlock, self).__init__()
        self.act_func = act_func
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act_func(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.act_func(out)

        return out


class UNet(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.args = args

        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = VGGBlock(args['input_channels'], nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], 1, kernel_size=1)


    def forward(self, input):
        x0_0 = self.conv0_0(input)
        print(x0_0.shape)
        x1_0 = self.conv1_0(self.pool(x0_0))
        print(x1_0.shape)
        x2_0 = self.conv2_0(self.pool(x1_0))
        print(x2_0.shape)
        x3_0 = self.conv3_0(self.pool(x2_0))
        print(x3_0.shape)
        x4_0 = self.conv4_0(self.pool(x3_0))
        print(x4_0.shape)
        print(" ")
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        print(x3_1.shape)
        x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
        print(x2_2.shape)
        x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
        print(x1_3.shape)
        x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))
        print(x0_4.shape)

        output = self.final(x0_4)
        output = torch.sigmoid(output)  # F.sigmoid is deprecated; torch.sigmoid is equivalent

        print("THE OUTPUT IS:", output.shape)
        return output

Whenever I print the shape at each line, my console prints the same line 4 times.
Is there something possibly wrong with the architecture? I am using a batch size of 8 with an input channel of 1.

torch.Size([2, 32, 512, 512])
torch.Size([2, 32, 512, 512])
torch.Size([2, 32, 512, 512])
torch.Size([2, 32, 512, 512])
torch.Size([2, 64, 256, 256])
torch.Size([2, 64, 256, 256])
torch.Size([2, 64, 256, 256])
torch.Size([2, 64, 256, 256])
torch.Size([2, 128, 128, 128])
torch.Size([2, 128, 128, 128])
torch.Size([2, 128, 128, 128])
torch.Size([2, 128, 128, 128])
torch.Size([2, 256, 64, 64])
torch.Size([2, 256, 64, 64])
torch.Size([2, 256, 64, 64])
torch.Size([2, 256, 64, 64])
torch.Size([2, 512, 32, 32])
torch.Size([2, 512, 32, 32])

torch.Size([2, 512, 32, 32])

torch.Size([2, 512, 32, 32])

torch.Size([2, 256, 64, 64])
torch.Size([2, 256, 64, 64])
torch.Size([2, 256, 64, 64])
torch.Size([2, 256, 64, 64])
torch.Size([2, 128, 128, 128])
torch.Size([2, 128, 128, 128])
torch.Size([2, 128, 128, 128])
torch.Size([2, 128, 128, 128])
torch.Size([2, 64, 256, 256])
torch.Size([2, 64, 256, 256])
torch.Size([2, 64, 256, 256])
torch.Size([2, 64, 256, 256])
torch.Size([2, 32, 512, 512])

My loss doesn't decrease from 0.99 right now.

Okay, I found out that this repetition is because I use 4 GPUs. When I change to 1 GPU, the results are what I was expecting.
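
That also explains the batch dimension of 2 in the printed shapes. A sketch, assuming the model is wrapped in nn.DataParallel: the batch of 8 is split across the 4 GPUs, so every print in forward() fires once per replica on a chunk of 2.

# Sketch, assuming an nn.DataParallel wrapper and 4 visible GPUs.
model = UNet({'input_channels': 1}).cuda()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # replicates forward() on every GPU

x = torch.rand(8, 1, 512, 512).cuda()
out = model(x)    # each replica prints shapes for its [2, 1, 512, 512] chunk
print(out.shape)  # the gathered output is still [8, 1, 512, 512]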

Hi @Jay_Super,

Is your mask binary or do you have a multi-class segmentation problem?
After loading, what values are in your mask tensor and what values are in your images? Do you normalize the images?
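
For example, a quick inspection along these lines (a sketch, assuming the MyLidcDataset above) would answer both questions:

# Hypothetical inspection of one sample from the dataset above.
image, mask = dataset[0]

print(image.dtype, image.min().item(), image.max().item())  # normalized to [0, 1]?
print(mask.dtype, mask.unique())                            # binary {0., 1.} or multi-class?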