Right now I’m trying to implement a U-Net for image inpainting. The code runs without errors, and when I print the training and validation losses they do more or less converge. The output images roughly resemble the ground truth, but the quality is quite poor. When I look at the pixel values of the inpainted region, they are far from converging to the ground-truth values; they seem to change almost at random rather than settle.
On the other hand, I was able to do the identical task with MatConvNet, a MATLAB library, so I assume it should work in PyTorch as well.
So I suspect something is wrong with my U-Net implementation in PyTorch, but since I don’t get any errors I can’t tell where the mistake is. Below is my implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.bn = nn.BatchNorm2d(out_size)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
        self.bn2 = nn.BatchNorm2d(out_size)
        self.activation = activation

    def forward(self, x):
        x1 = self.activation(self.bn(self.conv(x)))
        out = self.activation(self.bn2(self.conv2(x1)))
        return out


class UNetLastBlock(nn.Module):
    # Bottleneck block: expands to out_size channels, then comes back down to in_size
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetLastBlock, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.bn = nn.BatchNorm2d(out_size)
        self.conv2 = nn.Conv2d(out_size, in_size, kernel_size, padding=1)
        self.bn2 = nn.BatchNorm2d(in_size)
        self.activation = activation

    def forward(self, x):
        x1 = self.activation(self.bn(self.conv(x)))
        out = self.activation(self.bn2(self.conv2(x1)))
        return out


class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetUpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, in_size, 2, stride=2)
        # in_size * 2 because the upsampled map is concatenated with the skip connection
        self.conv = nn.Conv2d(in_size * 2, in_size, kernel_size, padding=1)
        self.bn = nn.BatchNorm2d(in_size)
        self.conv2 = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.bn2 = nn.BatchNorm2d(out_size)
        self.activation = activation

    def forward(self, x, bridge):
        up = self.up(x)
        out = torch.cat([up, bridge], dim=1)
        out = self.activation(self.bn(self.conv(out)))
        out = self.activation(self.bn2(self.conv2(out)))
        return out


class UNet(nn.Module):
    def __init__(self, in_c, out_c):
        super(UNet, self).__init__()
        self.activation = F.relu

        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.pool3 = nn.MaxPool2d(2)
        self.pool4 = nn.MaxPool2d(2)

        self.conv_block2_64 = UNetConvBlock(in_c, 64)
        self.conv_block64_128 = UNetConvBlock(64, 128)
        self.conv_block128_256 = UNetConvBlock(128, 256)
        self.conv_block256_512 = UNetConvBlock(256, 512)
        self.conv_block512_1024 = UNetLastBlock(512, 1024)

        self.up_block1024_512 = UNetUpBlock(512, 256)
        self.up_block512_256 = UNetUpBlock(256, 128)
        self.up_block256_128 = UNetUpBlock(128, 64)
        self.up_block128_64 = UNetUpBlock(64, 64)

        self.last = nn.Conv2d(64, out_c, 1)

    def forward(self, x):
        # Encoder (contracting path)
        block1 = self.conv_block2_64(x)
        pool1 = self.pool1(block1)
        block2 = self.conv_block64_128(pool1)
        pool2 = self.pool2(block2)
        block3 = self.conv_block128_256(pool2)
        pool3 = self.pool3(block3)
        block4 = self.conv_block256_512(pool3)
        pool4 = self.pool4(block4)

        # Bottleneck
        block5 = self.conv_block512_1024(pool4)

        # Decoder (expanding path) with skip connections
        up1 = self.up_block1024_512(block5, block4)
        up2 = self.up_block512_256(up1, block3)
        up3 = self.up_block256_128(up2, block2)
        up4 = self.up_block128_64(up3, block1)

        return self.last(up4)
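For what it's worth, the forward pass itself runs and produces an output of the expected shape. A quick check like the following works for me (the sizes here are just for illustration, not my actual image size):

# Dummy shape check; H and W only need to be divisible by 16 so the four
# poolings and up-convolutions line up spatially.
net = UNet(in_c=3, out_c=3)
x = torch.randn(2, 3, 128, 128)
y = net(x)
print(y.shape)  # torch.Size([2, 3, 128, 128])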
Can anyone tell me if any part of this is wrong? Also, are there any rules of thumb for checking whether a model implementation is wrong?
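One check I have been considering (not sure if this is the standard approach) is to see whether the network can overfit a single example, i.e. drive the loss close to zero on one image. Roughly something like this, where the loss and optimizer are just placeholders for whatever is actually used:

# Rough sketch of an overfit-one-example check; criterion, optimizer and the
# random tensors are placeholders, not my real training code.
net = UNet(in_c=3, out_c=3)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

x = torch.randn(1, 3, 128, 128)       # one corrupted input image
target = torch.randn(1, 3, 128, 128)  # its ground-truth counterpart

net.train()
for step in range(200):
    optimizer.zero_grad()
    loss = criterion(net(x), target)
    loss.backward()
    optimizer.step()
    if step % 50 == 0:
        print(step, loss.item())
# If the loss cannot be driven close to zero even on a single example,
# the model or the optimization itself is suspect rather than the data.

Is that a reasonable way to narrow it down, or is there a better diagnostic?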
Thank you in advance.