Hello,
I’m currently trying to implement the U-Net paper, but I can’t get my implementation to work properly.
Here is the U-Net architecture from the paper:
The difficult part is the cropping marked in red. How can I crop a batch of feature maps from 136x136 to 104x104 inside my network? I believe I could do this with NumPy, but that would mean transferring the tensors from the GPU to the CPU and transforming each image one by one. I would like to do the cropping on the GPU, on all the images at once.
So far I have this implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
class ConvRelu2d(nn.Module):
    """A 2-D convolution immediately followed by a ReLU activation.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Kernel size, forwarded to ``nn.Conv2d``.
        padding: Zero-padding, forwarded to ``nn.Conv2d``.
        stride: Stride, forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, stride):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, padding=padding, stride=stride)
        # Register ``conv`` before ``relu`` so parameter creation order
        # (and therefore seeded initialisation) stays the same.
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv(x))``."""
        return self.relu(self.conv(x))
class UNetOriginal(nn.Module):
    """U-Net (Ronneberger et al., 2015) built from unpadded ("valid") convolutions.

    Every 3x3 convolution uses ``padding=0`` and therefore trims 2 pixels from
    each spatial dimension. As a result, each contracting-path feature map is
    larger than the expansive-path map it is joined with. It is center-cropped
    before concatenation; for a 572x572 input, the 256-channel skip is cropped
    from 136x136 to 104x104.

    Args:
        in_shape: ``(channels, height, width)`` of the input images. Only
            ``channels`` is used to build the network.
        n_classes: Number of channels of the output segmentation map
            (2 in the paper).
    """

    def __init__(self, in_shape, n_classes=2):
        super().__init__()
        channels, _height, _width = in_shape

        # Contracting path: two valid 3x3 convs per level, then a 2x2 max-pool.
        self.down1 = nn.Sequential(
            ConvRelu2d(channels, 64, kernel_size=(3, 3), stride=1, padding=0),
            ConvRelu2d(64, 64, kernel_size=(3, 3), stride=1, padding=0),
        )
        self.maxPool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.down2 = nn.Sequential(
            ConvRelu2d(64, 128, kernel_size=(3, 3), stride=1, padding=0),
            ConvRelu2d(128, 128, kernel_size=(3, 3), stride=1, padding=0),
        )
        self.maxPool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.down3 = nn.Sequential(
            ConvRelu2d(128, 256, kernel_size=(3, 3), stride=1, padding=0),
            ConvRelu2d(256, 256, kernel_size=(3, 3), stride=1, padding=0),
        )
        self.maxPool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.down4 = nn.Sequential(
            ConvRelu2d(256, 512, kernel_size=(3, 3), stride=1, padding=0),
            ConvRelu2d(512, 512, kernel_size=(3, 3), stride=1, padding=0),
        )
        self.maxPool4 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.center = nn.Sequential(
            ConvRelu2d(512, 1024, kernel_size=(3, 3), stride=1, padding=0),
            ConvRelu2d(1024, 1024, kernel_size=(3, 3), stride=1, padding=0),
        )

        # Expansive path. Each "up-conv 2x2" doubles the spatial size and
        # halves the channel count. Concatenating the cropped skip (which has
        # that same halved count) doubles it again, which is why each ``up*``
        # block's first conv takes twice its output channels.
        # nn.Upsample is not used: it keeps the channel count, and giving it
        # both ``size`` and ``scale_factor`` makes F.interpolate raise.
        self.upSample1 = nn.ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=2)
        self.up1 = nn.Sequential(
            ConvRelu2d(1024, 512, kernel_size=(3, 3), stride=1, padding=0),
            ConvRelu2d(512, 512, kernel_size=(3, 3), stride=1, padding=0),
        )
        self.upSample2 = nn.ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=2)
        self.up2 = nn.Sequential(
            ConvRelu2d(512, 256, kernel_size=(3, 3), stride=1, padding=0),
            ConvRelu2d(256, 256, kernel_size=(3, 3), stride=1, padding=0),
        )
        self.upSample3 = nn.ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=2)
        self.up3 = nn.Sequential(
            ConvRelu2d(256, 128, kernel_size=(3, 3), stride=1, padding=0),
            ConvRelu2d(128, 128, kernel_size=(3, 3), stride=1, padding=0),
        )
        self.upSample4 = nn.ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=2)
        self.up4 = nn.Sequential(
            ConvRelu2d(128, 64, kernel_size=(3, 3), stride=1, padding=0),
            ConvRelu2d(64, 64, kernel_size=(3, 3), stride=1, padding=0),
        )

        # 1x1 convolution mapping the 64 features to per-class scores.
        self.output_seg_map = nn.Conv2d(64, n_classes, kernel_size=(1, 1), padding=0, stride=1)

    def _crop_concat(self, x, y):
        """Center-crop ``y`` to the spatial size of ``x`` and concatenate them.

        The crop is plain tensor slicing. It therefore runs on whatever device
        the tensors live on (no CPU round-trip) and processes the whole batch
        in one operation. Used for the skip connections of the expansive path.

        Args:
            x: Upsampled tensor of shape ``(N, C1, h, w)``.
            y: Skip-connection tensor of shape ``(N, C2, H, W)`` with
                ``H >= h`` and ``W >= w``.

        Returns:
            Tensor of shape ``(N, C1 + C2, h, w)``: ``x`` followed by the
            cropped ``y`` along the channel axis.

        Raises:
            ValueError: If ``y`` is spatially smaller than ``x``.
        """
        to_height, to_width = x.shape[-2:]
        from_height, from_width = y.shape[-2:]
        if from_height < to_height or from_width < to_width:
            raise ValueError(
                f"cannot crop skip tensor of size {from_height}x{from_width} "
                f"to the larger size {to_height}x{to_width}"
            )
        # When the size difference is odd, the extra row/column is dropped
        # from the bottom/right edge.
        top = (from_height - to_height) // 2
        left = (from_width - to_width) // 2
        y = y[..., top:top + to_height, left:left + to_width]
        return torch.cat([x, y], 1)  # on the channel axis

    def forward(self, x):
        """Compute per-pixel class scores.

        Args:
            x: Input batch of shape ``(N, channels, H, W)``.

        Returns:
            Tensor of shape ``(N, n_classes, H', W')``. Because the
            convolutions are unpadded, ``H'``/``W'`` are smaller than the input
            (for a 572x572 input the output is 388x388).
        """
        # Contracting path: keep each level's output for the skip connections.
        out_down1 = self.down1(x)
        out_down2 = self.down2(self.maxPool1(out_down1))
        out_down3 = self.down3(self.maxPool2(out_down2))
        out_down4 = self.down4(self.maxPool3(out_down3))
        x = self.center(self.maxPool4(out_down4))

        # Expansive path: up-convolve, crop + concatenate the skip, then convolve.
        x = self.up1(self._crop_concat(self.upSample1(x), out_down4))
        x = self.up2(self._crop_concat(self.upSample2(x), out_down3))
        x = self.up3(self._crop_concat(self.upSample3(x), out_down2))
        x = self.up4(self._crop_concat(self.upSample4(x), out_down1))
        return self.output_seg_map(x)
I came across this, but it’s not really what I want, because it takes a Pillow (PIL) image as input.
Thanks