Hey everyone,
I am new to Pytorch and facing the following problem:
To practice using Pytorch, I tried implementing my own U-Net. I have followed a Youtube tutorial on how to achieve this and tried to then code everything myself.
It seems like my code works mostly, but I have problems getting my U-Net Architecture to work properly.
When I run my code, the segmentation predictions are all black (with dice-score=0). When I copy-paste someone else’s architecture into my code, everything works well. Can anyone spot the error? I guess I just haven’t properly understood how architectures are defined…
Any help would be greatly appreciated!
My Architecture is defined like so:
import torch
import torch.nn as nn
#Conv -> Conv
class Double(nn.Module):
def __init__(self, in_channels, out_channels):
super(Double, self).__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels,3,1,1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels,3,1,1),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
#Downsampling -> Double
class Down(nn.Module):
def __init__(self):
super(Down, self).__init__()
self.down = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2)
)
def forward(self, x):
return self.down(x)
class Up(nn.Module):
def __init__(self, in_channels, out_channels):
super(Up, self).__init__()
self.up = nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
)
def forward(self, x, y):
x = self.up(x)
return torch.cat([x, y], dim=1)
class Final(nn.Module):
def __init__(self, in_channels, out_channels):
super(Final, self).__init__()
self.final = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.final(x)
class UNET(nn.Module):
def __init__(self, in_channels, out_channels):
super(UNET, self).__init__()
self.L1 = Double(in_channels, 16)
self.dow = Down()
self.L2 = Double(16, 32)
self.L3 = Double(32,64)
self.L4 = Double(64, 128)
self.L5 = Double(128, 256)
self.u6 = Up(256, 128)
self.L6 = Double(256, 128)
self.u7 = Up(128, 64)
self.L7 = Double(128, 64)
self.u8 = Up(64, 32)
self.L8 = Double(64, 32)
self.u9 = Up(32, 16)
self.L9 = Double(32, 16)
self.OUT = Final(16, out_channels)
def forward(self, x):
#Contraction
C1 = self.L1(x)
C2 = self.dow(C1)
C2 = self.L2(C2)
C3 = self.dow(C2)
C3 = self.L3(C3)
C4 = self.dow(C3)
C4 = self.L4(C4)
C5 = self.dow(C4)
C5 = self.L5(C5)
#Expansion
C6 = self.u6(C5, C4)
C6 = self.L6(C6)
C7 = self.u7(C6, C3)
C7 = self.L7(C7)
C8 = self.u8(C7, C2)
C8 = self.L8(C8)
C9 = self.u9(C8, C1)
C9 = self.L9(C9)
return self.OUT(C9)