RuntimeError: Given groups=1, weight of size [512, 1024, 3, 3], expected input[1, 1536, 32, 32] to have 1024 channels, but got 1536 channels instead

I am trying to implement a UNet architecture and I am running into the following error:

RuntimeError: Given groups=1, weight of size [512, 1024, 3, 3], expected input[1, 1536, 32, 32] to have 1024 channels, but got 1536 channels instead

I have two files: one is unet_layer and the other is unet_model.
The unet_layer file is:

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import torch.nn.functional as F

class ThreeConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ThreeConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class DownStage(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),  # Max-pooling operation with kernel size 2x2 and stride 2 for down-sampling
            ThreeConvBlock(in_channels, out_channels)  # Applying the ThreeConvBlock for down-sampling
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class UpStage(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        
        if bilinear:
            # Bilinear upsampling with scale factor 2 and mode 'bilinear' for up-sampling
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = ThreeConvBlock(in_channels, out_channels)  # Applying the ThreeConvBlock for up-sampling
        else:
            # Transposed convolution (deconvolution) with kernel size 2x2 and stride 2 for up-sampling
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = ThreeConvBlock(in_channels, out_channels)  # Applying the ThreeConvBlock for up-sampling
        
    def forward(self, x1, x2):
        x1 = self.up(x1)

        # Adjusting the spatial dimensions of x1 to match x2
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        # Padding x1 to match the spatial size of x2 before concatenation
        x1 = F.pad(x1,
                   [diffX // 2, diffX - diffX // 2,
                    diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)  # Concatenating the upsampled feature map (x1) with the skip connection (x2)
        return self.conv(x)  # Applying ThreeConvBlock to the concatenated feature map

And the unet_model is:

from unet_layer import *
import torch.nn as nn

class UNET(nn.Module):
    def __init__(self, n_channels, n_classes=1):
        super(UNET, self).__init__()
        self.nclasses = n_classes
        self.n_channels = n_channels

        # Define the initial block (ThreeConvBlock) of the UNET
        self.start = ThreeConvBlock(n_channels, 64)

        # Define the down-sampling stages (DownStage) of the UNET
        self.down1 = DownStage(64, 128)
        self.down2 = DownStage(128, 256)
        self.down3 = DownStage(256, 512)
        self.down4 = DownStage(512, 1024)

        # Define the up-sampling stages (UpStage) of the UNET
        self.up1 = UpStage(1024, 512)
        self.up2 = UpStage(512, 256)
        self.up3 = UpStage(256, 128)
        self.up4 = UpStage(128, 64)

        # Output layer (final convolution) for producing segmentation masks
        self.out = nn.Conv2d(64, self.nclasses, kernel_size=1)

    def forward(self, x):
        # Forward pass through the UNET architecture

        # Initial block
        x_start = self.start(x)

        # Down-sampling stages
        x_down1 = self.down1(x_start)
        x_down2 = self.down2(x_down1)
        x_down3 = self.down3(x_down2)
        x_down4 = self.down4(x_down3)
        print(x_down4.shape)

        # Up-sampling stages with skip connections
        x_up1 = self.up1(x_down4, x_down3)

        x_up2 = self.up2(x_up1, x_down2)
        x_up3 = self.up3(x_up2, x_down1)
        x_up4 = self.up4(x_up3, x_start)

        # Final output layer (convolution) for producing segmentation masks
        x_out = self.out(x_up4)
        return x_out

If I try to run the model with a simple test input:

    input_data = torch.randn(1, 3, 256, 256)  # Batch size of 1, 3 input channels, and 256x256 resolution
    output_masks = model(input_data.to(device))
    print(output_masks)
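
Here model and device are set up roughly as follows (the exact setup is not shown; n_channels=3 matches the 3-channel test input):

    # assumed setup, not shown in the original files
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNET(n_channels=3, n_classes=1).to(device)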

I get the error shown above.
Can anyone pinpoint the architectural problem? I have tried hard to track it down but have not been able to.
Thank you.

The error is most likely raised in:

x_up1 = self.up1(x_down4, x_down3)

With the default bilinear=True, nn.Upsample does not change the channel count, so the upsampled x_down4 still has 1024 channels; concatenating it with the 512-channel skip connection x_down3 produces 1024 + 512 = 1536 channels, while the internal ThreeConvBlock was built with in_channels=1024. Fix this by changing the in_channels value of that conv block so it matches the concatenated tensor (1536 for this stage), or use the transposed-convolution branch, which halves the channels before the concatenation.
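
For reference, here is a minimal sketch of that fix in the bilinear branch of UpStage. It assumes the skip connection always carries in_channels // 2 channels, which holds for the 64/128/256/512/1024 channel progression used in this UNET:

    class UpStage(nn.Module):
        def __init__(self, in_channels, out_channels, bilinear=True):
            super().__init__()

            if bilinear:
                # Bilinear upsampling keeps the channel count, so after concatenating
                # with the skip connection the conv block sees
                # in_channels + in_channels // 2 channels (e.g. 1024 + 512 = 1536).
                self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
                self.conv = ThreeConvBlock(in_channels + in_channels // 2, out_channels)
            else:
                # The transposed convolution already halves the channels, so the
                # concatenation yields in_channels // 2 + in_channels // 2 = in_channels.
                self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
                self.conv = ThreeConvBlock(in_channels, out_channels)

        def forward(self, x1, x2):
            # unchanged: upsample, pad to match x2, concatenate, then convolve
            x1 = self.up(x1)
            diffY = x2.size()[2] - x1.size()[2]
            diffX = x2.size()[3] - x1.size()[3]
            x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                            diffY // 2, diffY - diffY // 2])
            x = torch.cat([x2, x1], dim=1)
            return self.conv(x)

With that change the (1, 3, 256, 256) test input should run through the model and produce an output of shape (1, 1, 256, 256) for n_classes=1. If you want UpStage to be independent of this particular channel progression, pass the skip connection's channel count as an explicit constructor argument instead of deriving it from in_channels.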