Resnet101 encoder with U-Net decoder from scratch - tensor size issue

Hello! I am relatively new to the topic and I am trying to implement a resnet101 encoder with U-Net decoder. First of all, I am unsure whether I should do encoding, then pooling and then perform upsampling in order to concatenate the center with the conv5, or if I should skip this step and go straight to concatenating conv5 with conv4 in the decoding step? I am also not sure how should the sizes of the tensors be. I tried to achieve the following sizes in the decoding step but I don’t manage:
dec5: (4, 512, 32, 32)
dec4: (4, 256, 64, 64)
dec3: (4, 128, 128, 128)
dec2: (4, 64, 256, 256)
Is this even the correct approach?

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super(DoubleConv, self).__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels, mid_channels=mid_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None, bilinear=True, resnet_encoder=False):
        super(Up, self).__init__()
        self.resnet_encoder = resnet_encoder
        if not mid_channels:
            mid_channels = out_channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, mid_channels=mid_channels)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels, mid_channels=mid_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNetBase(nn.Module):
    def __init__(self, num_classes, bottom_channel_nr=None, num_filters=32, bilinear=True, resnet_encoder=False):
        super(UNetBase, self).__init__()
        self.bilinear = bilinear
        self.resnet_encoder = resnet_encoder
        factor = 2 if bilinear else 1

        if not self.resnet_encoder:
            bottom_channel_nr = num_filters * 32

        if resnet_encoder:
            self.up1 = Up(bottom_channel_nr + num_filters * 32, num_filters * 16, num_filters * 16, bilinear, resnet_encoder=resnet_encoder)
            self.up2 = Up((bottom_channel_nr // 2), num_filters * 8, num_filters * 8, bilinear, resnet_encoder=resnet_encoder)
            self.up3 = Up((bottom_channel_nr // 4) + num_filters * 8, num_filters * 8, num_filters * 2, bilinear, resnet_encoder=resnet_encoder)
            self.up4 = Up((bottom_channel_nr // 8) + num_filters * 2, num_filters * 4, num_filters * 4, bilinear, resnet_encoder=resnet_encoder)
        else:
            self.up1 = Up(num_filters * 32, num_filters * 16 // factor, bilinear=bilinear, resnet_encoder=resnet_encoder)
            self.up2 = Up(num_filters * 16, num_filters * 8 // factor, bilinear=bilinear, resnet_encoder=resnet_encoder)
            self.up3 = Up(num_filters * 8, num_filters * 4 // factor, bilinear=bilinear, resnet_encoder=resnet_encoder)
            self.up4 = Up(num_filters * 4, num_filters * 2, bilinear=bilinear, resnet_encoder=resnet_encoder)
        self.outc = OutConv(num_filters * 2, num_classes)

    def decode(self, features):
        x1, x2, x3, x4, x5 = features
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

class UNetResNet(UNetBase):
    def __init__(self, encoder_depth, num_classes, bilinear=True, pretrained=True):
        if encoder_depth == 34:
            bottom_channel_nr = 512
        elif encoder_depth in [101, 152]:
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('ResNet encoder_depth should be 34, 101, or 152')

        super(UNetResNet, self).__init__(num_classes, bottom_channel_nr, bilinear=bilinear, resnet_encoder=True)
        self.encoder = models.__dict__[f'resnet{encoder_depth}'](pretrained=pretrained)
        
        self.pool = nn.MaxPool2d(2, 2)

        self.conv1 = nn.Sequential(self.encoder.conv1, self.encoder.bn1, self.encoder.relu, self.pool)
        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4

        self.center = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 
            DoubleConv(bottom_channel_nr, 32 * 8) 
            )

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x4 = self.conv4(x3)
        x5 = self.conv5(x4)
        pool = self.pool(x5)
        center = self.center(pool)
        return self.decode([x2, x3, x4, x5, center])

if __name__ == "__main__":
    input_tensor = torch.randn(4, 3, 512, 512).cuda()
    unet_resnet = UNetResNet(encoder_depth=101, num_classes=1, bilinear=True, pretrained=True).cuda()
    output_resnet = unet_resnet(input_tensor)
    print("UNetResNet Output shape:", output_resnet.shape)