Hello! I am relatively new to this topic and I am trying to implement a U-Net decoder on top of a ResNet101 encoder. First of all, I am unsure whether I should run the encoder, then pool its last feature map and upsample it again in a "center" block so that it can be concatenated with conv5, or whether I should skip this step and go straight to concatenating conv5 with conv4 in the decoding step. I am also not sure what the sizes of the tensors should be. I tried to achieve the following sizes in the decoding step, but I can't get there:
dec5: (4, 512, 32, 32)
dec4: (4, 256, 64, 64)
dec3: (4, 128, 128, 128)
dec2: (4, 64, 256, 256)
Is this even the correct approach?
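For reference, a quick shape check on the stock torchvision ResNet101 (no decoder attached) gives the following encoder shapes for a 512x512 input; the shapes in the comments are the standard ResNet101 ones:

import torch
from torchvision import models

# Stock torchvision ResNet101, used only to look at the encoder feature shapes.
enc = models.resnet101(pretrained=False).eval()
x = torch.randn(4, 3, 512, 512)
with torch.no_grad():
    x = enc.maxpool(enc.relu(enc.bn1(enc.conv1(x))))   # (4, 64, 128, 128)
    c2 = enc.layer1(x)                                 # (4, 256, 128, 128)
    c3 = enc.layer2(c2)                                # (4, 512, 64, 64)
    c4 = enc.layer3(c3)                                # (4, 1024, 32, 32)
    c5 = enc.layer4(c4)                                # (4, 2048, 16, 16)
for name, t in [("layer1", c2), ("layer2", c3), ("layer3", c4), ("layer4", c5)]:
    print(name, tuple(t.shape))

And here is my full attempt: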
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super(DoubleConv, self).__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels, mid_channels=mid_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None, bilinear=True, resnet_encoder=False):
        super(Up, self).__init__()
        self.resnet_encoder = resnet_encoder
        if not mid_channels:
            mid_channels = out_channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, mid_channels=mid_channels)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels, mid_channels=mid_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNetBase(nn.Module):
    def __init__(self, num_classes, bottom_channel_nr=None, num_filters=32, bilinear=True, resnet_encoder=False):
        super(UNetBase, self).__init__()
        self.bilinear = bilinear
        self.resnet_encoder = resnet_encoder
        factor = 2 if bilinear else 1
        if not self.resnet_encoder:
            bottom_channel_nr = num_filters * 32
        if resnet_encoder:
            self.up1 = Up(bottom_channel_nr + num_filters * 32, num_filters * 16, num_filters * 16, bilinear, resnet_encoder=resnet_encoder)
            self.up2 = Up((bottom_channel_nr // 2), num_filters * 8, num_filters * 8, bilinear, resnet_encoder=resnet_encoder)
            self.up3 = Up((bottom_channel_nr // 4) + num_filters * 8, num_filters * 8, num_filters * 2, bilinear, resnet_encoder=resnet_encoder)
            self.up4 = Up((bottom_channel_nr // 8) + num_filters * 2, num_filters * 4, num_filters * 4, bilinear, resnet_encoder=resnet_encoder)
        else:
            self.up1 = Up(num_filters * 32, num_filters * 16 // factor, bilinear=bilinear, resnet_encoder=resnet_encoder)
            self.up2 = Up(num_filters * 16, num_filters * 8 // factor, bilinear=bilinear, resnet_encoder=resnet_encoder)
            self.up3 = Up(num_filters * 8, num_filters * 4 // factor, bilinear=bilinear, resnet_encoder=resnet_encoder)
            self.up4 = Up(num_filters * 4, num_filters * 2, bilinear=bilinear, resnet_encoder=resnet_encoder)
        self.outc = OutConv(num_filters * 2, num_classes)

    def decode(self, features):
        x1, x2, x3, x4, x5 = features
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

class UNetResNet(UNetBase):
    def __init__(self, encoder_depth, num_classes, bilinear=True, pretrained=True):
        if encoder_depth == 34:
            bottom_channel_nr = 512
        elif encoder_depth in [101, 152]:
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('ResNet encoder_depth should be 34, 101, or 152')
        super(UNetResNet, self).__init__(num_classes, bottom_channel_nr, bilinear=bilinear, resnet_encoder=True)
        self.encoder = models.__dict__[f'resnet{encoder_depth}'](pretrained=pretrained)
        self.pool = nn.MaxPool2d(2, 2)
        # Encoder stages (for ResNet101 the channel counts are 64, 256, 512, 1024, 2048)
        self.conv1 = nn.Sequential(self.encoder.conv1, self.encoder.bn1, self.encoder.relu, self.pool)
        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4
        # "Center" block: applied after an extra pooling of conv5, upsamples back and reduces channels
        self.center = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            DoubleConv(bottom_channel_nr, 32 * 8)
        )

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x4 = self.conv4(x3)
        x5 = self.conv5(x4)
        pool = self.pool(x5)
        center = self.center(pool)
        # Skip connections passed to the decoder: conv2..conv5 plus the center block
        return self.decode([x2, x3, x4, x5, center])

if __name__ == "__main__":
    input_tensor = torch.randn(4, 3, 512, 512).cuda()
    unet_resnet = UNetResNet(encoder_depth=101, num_classes=1, bilinear=True, pretrained=True).cuda()
    output_resnet = unet_resnet(input_tensor)
    print("UNetResNet Output shape:", output_resnet.shape)