PyTorch outputs black image as segmentation mask

Hey everyone,

I am new to PyTorch and am facing the following problem:

To practice using PyTorch, I tried implementing my own U-Net. I followed a YouTube tutorial on how to achieve this and then tried to code everything myself.

It seems like my code mostly works, but I am having trouble getting my U-Net architecture to work properly.

When I run my code, the segmentation predictions are all black (with a Dice score of 0). When I copy-paste someone else’s architecture into my code, everything works fine. Can anyone spot the error? I guess I just haven’t properly understood how architectures are defined…

Any help would be greatly appreciated!

My architecture is defined like so:

import torch
import torch.nn as nn

# Two (Conv -> ReLU) blocks
class Double(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Double, self).__init__()
        
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,3,1,1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels,3,1,1),
            nn.ReLU(inplace=True)
            )
    
    def forward(self, x):
        return self.double_conv(x)

# Downsampling (2x2 max pooling)
class Down(nn.Module):
    def __init__(self):
        super(Down, self).__init__()
        
        self.down = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2)
            )
    
    def forward(self, x):
        return self.down(x)
    

class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        
        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
            
            )
        
    def forward(self, x, y):
        # upsample x, then concatenate the skip connection y along the channel dim
        x = self.up(x)
        return torch.cat([x, y], dim=1)
    
class Final(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Final, self).__init__()
        
        self.final = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # 1x1 conv mapping features to output channels
        
    def forward(self, x):
        return self.final(x)

class UNET(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNET, self).__init__()
        
        self.L1 = Double(in_channels, 16)
        self.dow = Down()
        self.L2 = Double(16, 32)
        self.L3 = Double(32,64)
        self.L4 = Double(64, 128)
        self.L5 = Double(128, 256)
        self.u6 = Up(256, 128)
        self.L6 = Double(256, 128)
        self.u7 = Up(128, 64)
        self.L7 = Double(128, 64)
        self.u8 = Up(64, 32)
        self.L8 = Double(64, 32)
        self.u9 = Up(32, 16)
        self.L9 = Double(32, 16)
        self.OUT = Final(16, out_channels)
        
    def forward(self, x):
        #Contraction
        C1 = self.L1(x)
        
        C2 = self.dow(C1)
        C2 = self.L2(C2)
        
        C3 = self.dow(C2)
        C3 = self.L3(C3)
        
        C4 = self.dow(C3)
        C4 = self.L4(C4)
        
        C5 = self.dow(C4)
        C5 = self.L5(C5)
        
        #Expansion
        C6 = self.u6(C5, C4)
        C6 = self.L6(C6)
        
        C7 = self.u7(C6, C3)
        C7 = self.L7(C7)
        
        C8 = self.u8(C7, C2)
        C8 = self.L8(C8)
        
        C9 = self.u9(C8, C1)
        C9 = self.L9(C9)
        
        return self.OUT(C9)
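
For reference, here is a quick shape check (the 160x160 size is just an example I picked; with four pooling steps, the input height and width need to be divisible by 16 so the skip connections line up):

if __name__ == "__main__":
    model = UNET(in_channels=1, out_channels=1)
    x = torch.randn(1, 1, 160, 160)  # (batch, channels, height, width)
    print(model(x).shape)  # torch.Size([1, 1, 160, 160]), same spatial size as the input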

Can you share your Dice coefficient implementation?

I use loss = nn.BCEWithLogitsLoss() as my loss function; the Dice score really is just a temporary metric I calculate to see how my network is doing. But anyway, here it is:

def accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)  # add a channel dimension to the mask
            preds = torch.sigmoid(model(x))  # logits -> probabilities
            preds = (preds > 0.5).float()    # threshold to a binary mask
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            # per-batch Dice; the epsilon guards against division by zero
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

    print("Pixel accuracy:", num_correct / num_pixels)
    print("Dice-score:", dice_score / len(loader))  # average over batches

    model.train()
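
One quick way to see where the all-black masks come from is to look at the raw outputs before thresholding (just a diagnostic sketch; loader, model and device are the same ones used above):

with torch.no_grad():
    x, y = next(iter(loader))
    logits = model(x.to(device))
    probs = torch.sigmoid(logits)
    # if probs.max() never gets above 0.5, every pixel is thresholded
    # to 0, which is exactly the all-black mask / Dice score of 0
    print(logits.min().item(), logits.max().item(), probs.max().item())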

When I insert another U-Net architecture into my code (for example this one), the network works…

Okay, I got it to work…

I added a batch normalization layer before each ReLU in the Double class. That is, I changed

class Double(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Double, self).__init__()
        
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,3,1,1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels,3,1,1),
            nn.ReLU(inplace=True)
            )
    
    def forward(self, x):
        return self.double_conv(x)

to:

class Double(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Double, self).__init__()
        
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,3,1,1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels,3,1,1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
            )
    
    def forward(self, x):
        return self.double_conv(x)

And it does work now. But I wonder why it didn’t work before. When U-Net was developed, batch normalization wasn’t a thing yet, was it? So how did they get it to work?
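
For what it’s worth, from skimming the original U-Net paper, my guess is that they relied on careful weight initialization instead: the weights are drawn from a Gaussian with standard deviation sqrt(2/N), where N is the number of incoming connections per neuron, which is essentially He initialization. A rough sketch of how one could apply that here:

def init_weights(m):
    # He/Kaiming initialization, roughly what the U-Net paper describes
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        if m.bias is not None:
            nn.init.zeros_(m.bias)

model = UNET(in_channels=1, out_channels=1)
model.apply(init_weights)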