I am trying to implement the U-Net architecture and I am running into this error:

RuntimeError: Given groups=1, weight of size [512, 1024, 3, 3], expected input[1, 1536, 32, 32] to have 1024 channels, but got 1536 channels instead

If I read the error correctly, a weight of size [512, 1024, 3, 3] belongs to a Conv2d that expects 1024 input channels and produces 512, but it is receiving an input with 1536 channels.
I have two files: unet_layer.py and unet_model.py.

unet_layer.py:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ThreeConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ThreeConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class DownStage(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),  # 2x2 max-pooling with stride 2 halves the spatial size
            ThreeConvBlock(in_channels, out_channels)  # conv block after down-sampling
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class UpStage(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # Bilinear upsampling with scale factor 2 (channel count is unchanged)
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = ThreeConvBlock(in_channels, out_channels)
        else:
            # Transposed convolution with kernel size 2x2 and stride 2 (halves the channels)
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = ThreeConvBlock(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad x1 so its spatial size matches the skip connection x2
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # Concatenate the skip connection with the upsampled feature map
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
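To isolate the problem, I tested a single UpStage on its own, with the same channel sizes as my first decoder stage (1024 in, 512 out, with a 512-channel skip connection). The transposed-convolution branch runs fine, but the bilinear branch reproduces the exact error:

import torch
from unet_layer import UpStage

# Transposed-conv branch: ConvTranspose2d(1024, 512) halves the channels,
# so after concatenating the 512-channel skip the conv block sees 1024 channels.
up_t = UpStage(1024, 512, bilinear=False)
print(up_t(torch.randn(1, 1024, 16, 16), torch.randn(1, 512, 32, 32)).shape)
# torch.Size([1, 512, 32, 32])

# Bilinear branch: nn.Upsample keeps all 1024 channels, so the concatenation
# yields 1024 + 512 = 1536 channels, but ThreeConvBlock(1024, 512) expects 1024.
up_b = UpStage(1024, 512, bilinear=True)
up_b(torch.randn(1, 1024, 16, 16), torch.randn(1, 512, 32, 32))  # -> RuntimeError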
And unet_model.py:
from unet_layer import *
import torch.nn as nn

class UNET(nn.Module):
    def __init__(self, n_channels, n_classes=1):
        super(UNET, self).__init__()
        self.n_classes = n_classes
        self.n_channels = n_channels
        # Initial conv block
        self.start = ThreeConvBlock(n_channels, 64)
        # Down-sampling (encoder) stages
        self.down1 = DownStage(64, 128)
        self.down2 = DownStage(128, 256)
        self.down3 = DownStage(256, 512)
        self.down4 = DownStage(512, 1024)
        # Up-sampling (decoder) stages
        self.up1 = UpStage(1024, 512)
        self.up2 = UpStage(512, 256)
        self.up3 = UpStage(256, 128)
        self.up4 = UpStage(128, 64)
        # Final 1x1 convolution producing the segmentation masks
        self.out = nn.Conv2d(64, self.n_classes, kernel_size=1)

    def forward(self, x):
        # Initial block
        x_start = self.start(x)
        # Encoder
        x_down1 = self.down1(x_start)
        x_down2 = self.down2(x_down1)
        x_down3 = self.down3(x_down2)
        x_down4 = self.down4(x_down3)
        print(x_down4.shape)  # debug print of the bottleneck shape
        # Decoder with skip connections
        x_up1 = self.up1(x_down4, x_down3)
        x_up2 = self.up2(x_up1, x_down2)
        x_up3 = self.up3(x_up2, x_down1)
        x_up4 = self.up4(x_up3, x_start)
        # Output layer
        x_out = self.out(x_up4)
        return x_out
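For reference, the encoder on its own runs without problems; tracing it stage by stage with a 256x256 input gives the shapes below (the last one matches the debug print in forward):

import torch
from unet_model import UNET

model = UNET(n_channels=3).eval()
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    x_start = model.start(x)        # torch.Size([1, 64, 256, 256])
    x_down1 = model.down1(x_start)  # torch.Size([1, 128, 128, 128])
    x_down2 = model.down2(x_down1)  # torch.Size([1, 256, 64, 64])
    x_down3 = model.down3(x_down2)  # torch.Size([1, 512, 32, 32])
    x_down4 = model.down4(x_down3)  # torch.Size([1, 1024, 16, 16])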
If I try to run the full model with a simple test:
import torch
from unet_model import UNET

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNET(n_channels=3).to(device)

input_data = torch.randn(1, 3, 256, 256)  # batch size 1, 3 input channels, 256x256 resolution
output_masks = model(input_data.to(device))
print(output_masks)
I get the error above: the debug print shows the bottleneck shape torch.Size([1, 1024, 16, 16]), and then the forward pass fails inside self.up1. Can anyone pinpoint the architectural problem? I have tried hard but have not been able to track it down; my current suspicion is sketched below.
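My suspicion (just a guess, so please correct me): nn.Upsample does not change the channel count, so in the bilinear branch of UpStage the conv block after torch.cat receives in_channels + in_channels // 2 channels (1024 + 512 = 1536 for up1), while it is built to expect only in_channels. The transposed-convolution branch halves the channels before the concatenation, which would explain why only the bilinear path fails. Would something like this be the correct adjustment to the bilinear branch?

if bilinear:
    self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
    # Upsample keeps in_channels, and the skip connection contributes
    # in_channels // 2 more, so the conv block must accept their sum.
    self.conv = ThreeConvBlock(in_channels + in_channels // 2, out_channels)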
Thank you.