RuntimeError: Given groups=1, weight of size [512, 1536, 3, 3], expected input[1, 1024, 32, 32] to have 1536 channels, but got 1024 channels instead

I get this error when running my UNet on a 256 x 256 input image. My code is below. What is causing it?

from unet_parts import *

class UNet(nn.Module):
    """U-Net with one shared encoder/decoder and four output heads.

    The building blocks (``DoubleConv``, ``Down``, ``Up``, ``OutConv``) come
    from ``unet_parts`` and are not visible in this file.

    NOTE(review): the reported ``RuntimeError`` (a conv with weight
    ``[512, 1536, 3, 3]`` receiving a ``[1, 1024, 32, 32]`` tensor) is not
    raised by the input layer.
    - The weight's 512 output channels match ``up1`` (``Up(1024, 512)`` when
      ``bilinear=False``), and 32x32 is presumably the resolution after three
      ``Down`` stages of a 256x256 image.
    - That conv was therefore built for 1536 (= 1024 + 512) input channels,
      but the upsampled-plus-skip tensor it receives has only 1024.
    TODO: check in ``unet_parts.Up`` that the first conv's input width equals
    (channels after upsampling) + (skip-connection channels).

    Args:
        n_channels: Number of channels of the input image (default 1).
        n_classes: Number of output channels of each ``OutConv`` head.
        bilinear: Passed through to every ``Up`` block. When True, ``factor``
            is 2, which halves the widths of ``down4`` and ``up1``-``up3``.
    """

    def __init__(self, n_channels=1, n_classes=1, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        # Activation checkpointing is off until use_checkpointing() is called.
        self._checkpointing = False

        # Encoder.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        # Decoder; each Up block also receives the matching encoder output.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # Four independent heads over the same final 64-channel feature map.
        self.outc1 = OutConv(64, n_classes)
        self.outc2 = OutConv(64, n_classes)
        self.outc3 = OutConv(64, n_classes)
        self.outc4 = OutConv(64, n_classes)

    def _run(self, module, *inputs):
        """Call ``module(*inputs)``, via activation checkpointing if enabled."""
        if not getattr(self, "_checkpointing", False):
            return module(*inputs)
        from torch.utils.checkpoint import checkpoint

        # Non-reentrant checkpointing: the module's forward is recomputed
        # during backward instead of keeping its intermediate activations.
        return checkpoint(module, *inputs, use_reentrant=False)

    def forward(self, x: torch.Tensor):
        """Run the network on ``x``.

        Args:
            x: Input batch; assumed ``(N, n_channels, H, W)`` — TODO confirm
                against the ``unet_parts`` blocks.

        Returns:
            A tuple ``(fake_bin1, fake_bin2, fake_bin3, fake_bin4)``, one
            tensor per ``OutConv`` head, all computed from the same decoder
            output.
        """
        run = self._run
        x1 = run(self.inc, x)
        x2 = run(self.down1, x1)
        x3 = run(self.down2, x2)
        x4 = run(self.down3, x3)
        x5 = run(self.down4, x4)
        x = run(self.up1, x5, x4)
        x = run(self.up2, x, x3)
        x = run(self.up3, x, x2)
        x = run(self.up4, x, x1)
        fake_bin1 = run(self.outc1, x)
        fake_bin2 = run(self.outc2, x)
        fake_bin3 = run(self.outc3, x)
        fake_bin4 = run(self.outc4, x)

        return fake_bin1, fake_bin2, fake_bin3, fake_bin4

    def use_checkpointing(self, enabled=True):
        """Enable (or, with ``enabled=False``, disable) activation checkpointing.

        The sub-modules are left in place; ``forward`` wraps each call with
        ``torch.utils.checkpoint.checkpoint`` while the flag is set. Assigning
        ``torch.utils.checkpoint(module)`` to the attributes does not work:
        ``torch.utils.checkpoint`` is a module, not a callable.
        """
        self._checkpointing = bool(enabled)

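Add a channel dimension to your input image. If you are working in PyTorch and your image is grayscale, each sample should have shape (1, 256, 256), i.e. (1, 1, 256, 256) including the batch dimension. Note, however, that the tensor in the error message, [1, 1024, 32, 32], is a decoder feature map rather than the input image. The conv that fails (weight [512, 1536, 3, 3]) is the one inside the first `Up` block. So the mismatch most likely lies between the `Up` implementation in `unet_parts` and the channel counts passed to it: that conv expects 1024 + 512 = 1536 channels, but after upsampling and concatenating the skip connection it receives 1024.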