I am getting this channel mismatch error: RuntimeError: Given groups=1, weight of size [64, 2, 3, 3], expected input[1, 1, 256, 256] to have 2 channels, but got 1 channels instead
in_channels = 2, but I still can't understand the cause of the error.
Here is the segmentation model:
class SegmentationModel(nn.Module):
    """Convolutional encoder/decoder that emits a 2-channel map.

    The encoder is five 3x3 convolutions (padding 1, stride 1), so it keeps
    the spatial size of the input.  The decoder is four 2x2 stride-2
    transposed convolutions, each of which doubles height and width.
    NOTE(review): because the encoder never downsamples, the output is 16x
    larger than the input along each spatial axis (e.g. 256 -> 4096);
    callers have to resize the result if they need input-sized output.

    Args:
        in_channels: Number of channels of the tensors passed to
            ``forward``.  It must equal ``x.shape[1]`` at call time.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Encoder.  The first layer must honour ``in_channels``; hard-coding
        # the channel count here makes the constructor argument a no-op and
        # triggers a channel-mismatch RuntimeError for any other input.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)
        # Decoder.  Each transposed convolution doubles height and width.
        self.upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        # 1x1 projection down to the 2 output channels.
        self.final_conv = nn.Conv2d(64, 2, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: Tensor of shape ``[batch_size, in_channels, height, width]``.

        Returns:
            Tensor of shape ``[batch_size, 2, 16 * height, 16 * width]``.
            No activation is applied after the final convolution.

        Raises:
            RuntimeError: Raised by the first convolution when
                ``x.shape[1]`` differs from ``in_channels``.
        """
        # Encoder: ReLU after every convolution, spatial size unchanged.
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(x1))
        x3 = F.relu(self.conv3(x2))
        x4 = F.relu(self.conv4(x3))
        x5 = F.relu(self.conv5(x4))
        # Decoder: four 2x upsampling steps.
        x6 = F.relu(self.upconv1(x5))
        x7 = F.relu(self.upconv2(x6))
        x8 = F.relu(self.upconv3(x7))
        x9 = F.relu(self.upconv4(x8))
        output = self.final_conv(x9)
        return output
Here are the input output shapes:
Image shape: torch.Size([2, 256, 256])
Mask shape: torch.Size([2, 256, 256])
Output shape: torch.Size([1, 2, 4096, 4096])
Resized output shape: torch.Size([1, 2, 256, 256])
Resized_Image shape: torch.Size([1, 2, 256, 256])
Resized_Mask shape: torch.Size([1, 2, 256, 256])