RuntimeError: Given groups=1, weight of size [256, 1, 3, 3], expected input[4, 3, 256, 256] to have 1 channels, but got 3 channels instead

Can anyone point out the mistake and explain why I am getting this error? I am working on greyscale images with 5 classes, and all the images have the same size, 256×256. The code is given below. Any guidance would be appreciated.

class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    The input ``(batch, channels, height, width)`` tensor is flattened into a
    sequence of ``height * width`` tokens of width ``channels``, passed through
    a pre-norm attention block with a residual connection, then through a
    residual feed-forward block, and finally reshaped back to the input shape.

    Args:
        channels: Number of feature channels (the embedding width). Must be
            divisible by the 4 attention heads used here.
        size: Expected spatial side length of the feature map. It is stored
            for reference and API compatibility; ``forward`` reads the actual
            spatial dimensions from its input.
    """

    def __init__(self, channels, size):
        super().__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        """Apply self-attention to ``x`` and return a tensor of the same shape.

        Args:
            x: Tensor of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, channels, height, width)``.
        """
        # Use the real input shape rather than self.size: a hard-coded
        # ``view(-1, C, size * size)`` silently folds spatial positions into
        # the batch dimension when the resolution differs from ``size``.
        batch, channels, height, width = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # (batch, height * width, channels)

        normed = self.ln(tokens)
        attended, _ = self.mha(normed, normed, normed)
        attended = attended + tokens                    # residual around attention
        attended = self.ff_self(attended) + attended    # residual around feed-forward

        return attended.transpose(1, 2).reshape(batch, channels, height, width)

class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by a single-group GroupNorm.

    A GELU sits between the two conv/norm pairs. Both convolutions use
    ``padding=1``, so the spatial size of the input is preserved.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        mid_channels: Channels between the two convolutions. Falls back to
            ``out_channels`` when omitted (or otherwise falsy).
        residual: When true, the block returns ``gelu(x + double_conv(x))``;
            this requires ``in_channels == out_channels`` so the sum is valid.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        mid_channels = mid_channels or out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )

    def forward(self, x):
        """Run the conv stack on ``x``, adding the skip path when ``residual`` is set."""
        out = self.double_conv(x)
        if self.residual:
            return F.gelu(x + out)
        return out

class Down(nn.Module):
    """Downsampling stage: 2x max-pool, two DoubleConv blocks, plus a time embedding.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the returned feature map.
        emb_dim: Width of the time-embedding vector passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )
        # Projects the time embedding to one bias value per output channel.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, t):
        """Halve the spatial size of ``x`` and add the projected time embedding.

        Args:
            x: Feature map of shape ``(batch, in_channels, H, W)``.
            t: Time embedding of shape ``(batch, emb_dim)``.

        Returns:
            Tensor of shape ``(batch, out_channels, H // 2, W // 2)``.
        """
        x = self.maxpool_conv(x)
        # (batch, out_channels, 1, 1) broadcasts over the spatial dimensions,
        # giving the same values as an explicit ``.repeat`` without the copy.
        emb = self.emb_layer(t)[:, :, None, None]
        return x + emb

class Up(nn.Module):
    """Upsampling stage: 2x bilinear upsample, skip concatenation, two DoubleConv blocks.

    Args:
        in_channels: Channels of the tensor AFTER concatenating the skip
            connection with the upsampled input (skip channels + input channels).
        out_channels: Channels of the returned feature map.
        emb_dim: Width of the time-embedding vector passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )
        # Projects the time embedding to one bias value per output channel.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, skip_x, t):
        """Upsample ``x``, fuse it with ``skip_x`` and add the projected time embedding.

        Args:
            x: Feature map from the previous (coarser) stage.
            skip_x: Encoder feature map at the target resolution; its spatial
                size must equal twice that of ``x`` for the concatenation to work.
            t: Time embedding of shape ``(batch, emb_dim)``.

        Returns:
            Tensor with ``out_channels`` channels at the resolution of ``skip_x``.
        """
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        # (batch, out_channels, 1, 1) broadcasts over the spatial dimensions,
        # giving the same values as an explicit ``.repeat`` without the copy.
        emb = self.emb_layer(t)[:, :, None, None]
        return x + emb

class UNet(nn.Module):
    """Time-conditioned U-Net with self-attention after every down/up stage.

    The ``SelfAttention`` size arguments (128, 64, 32, ...) correspond to a
    256x256 input. The first convolution is built with ``c_in`` input
    channels, so the input batch must have exactly ``c_in`` channels: with the
    default ``c_in=1``, images loaded as 3-channel RGB must be converted to a
    single channel first (or the model built with ``c_in=3``).

    Args:
        c_in: Channels of the input image.
        c_out: Channels of the predicted output.
        time_dim: Width of the sinusoidal time embedding. Must equal the
            ``emb_dim`` of the Down/Up stages (both default to 256).
        device: Stored on the instance for callers; the positional encoding
            itself follows the device of the timestep tensor.

    NOTE(review): sa1/sa5 attend over 128 * 128 = 16384 tokens and sa6 over
    256 * 256 = 65536 tokens. The attention matrix grows with the square of
    the token count, so these layers are likely to exhaust memory; consider
    dropping attention at the highest resolutions.
    """

    def __init__(self, c_in=1, c_out=1, time_dim=256, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim

        # Encoder: channels double while the resolution halves.
        self.inc = DoubleConv(c_in, 256)
        self.down1 = Down(256, 512)
        self.sa1 = SelfAttention(512, 128)
        self.down2 = Down(512, 1024)
        self.sa2 = SelfAttention(1024, 64)
        self.down3 = Down(1024, 2048)
        self.sa3 = SelfAttention(2048, 32)

        # Bottleneck: must accept the 2048 channels produced by down3 and hand
        # 1024 channels to up1, so that skip (1024) + upsampled (1024) = 2048.
        self.bot1 = DoubleConv(2048, 2048)
        self.bot2 = DoubleConv(2048, 2048)
        self.bot3 = DoubleConv(2048, 1024)

        # Decoder: each Up's in_channels is skip channels + incoming channels,
        # and its out_channels equals the channels of the next skip tensor.
        self.up1 = Up(2048, 512)            # cat(x3: 1024, up: 1024) -> 512
        self.sa4 = SelfAttention(512, 64)
        self.up2 = Up(1024, 256)            # cat(x2: 512, up: 512) -> 256
        self.sa5 = SelfAttention(256, 128)
        self.up3 = Up(512, 256)             # cat(x1: 256, up: 256) -> 256
        self.sa6 = SelfAttention(256, 256)
        self.outc = nn.Conv2d(256, c_out, kernel_size=1)

    def pos_encoding(self, t, channels):
        """Return the sinusoidal embedding of timesteps ``t``.

        Args:
            t: Float tensor of shape ``(batch, 1)``.
            channels: Embedding width (even); half sine and half cosine terms.

        Returns:
            Tensor of shape ``(batch, channels)`` laid out as ``[sin..., cos...]``.
        """
        # Build the frequencies on t's device so the encoding also works when
        # the model runs somewhere other than the ``device`` constructor default.
        exponents = torch.arange(0, channels, 2, device=t.device).float() / channels
        inv_freq = 1.0 / (10000 ** exponents)
        angles = t.repeat(1, channels // 2) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x, t):
        """Predict the output for noisy images ``x`` at timesteps ``t``.

        Args:
            x: Noisy images of shape ``(batch, c_in, H, W)``.
            t: Timesteps of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, c_out, H, W)``.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        # Encoder path; x1..x3 are kept as skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        # Bottleneck.
        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        # Decoder path, fusing the matching encoder features at each scale.
        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        return self.outc(x)

Can you copy the actual error?

Also, please show where you instantiate the model, for example `model = UNet(c_in=1, ...)`, and how the images are loaded, since the error says the input batch has 3 channels.

1 Like