Runtime error: Given groups=1, weight of size [64, 4, 3, 3], expected input[32, 256, 256, 4] to have 4 channels, but got 256 channels instead

My input features have shape 256 × 256 × 4 (height × width × channels), the batch size is 32, and the output labels have shape 256 × 256 × 1. Below is my U-Net architecture:
import torch
import torch.nn as nn
import torch.nn.functional as F

class unet(nn.Module):
    def __init__(self):
        super(unet, self).__init__()

        # encoder
        self.e11 = nn.Conv2d(4, 64, kernel_size=3, padding=0)
        self.e12 = nn.Conv2d(64, 64, kernel_size=3, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.e21 = nn.Conv2d(64, 128, kernel_size=3, padding=0)
        self.e22 = nn.Conv2d(128, 128, kernel_size=3, padding=0)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.e31 = nn.Conv2d(128, 256, kernel_size=3, padding=0)
        self.e32 = nn.Conv2d(256, 256, kernel_size=3, padding=0)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.e41 = nn.Conv2d(256, 512, kernel_size=3, padding=0)
        self.e42 = nn.Conv2d(512, 512, kernel_size=3, padding=0)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.e51 = nn.Conv2d(512, 1024, kernel_size=3, padding=0)
        self.e52 = nn.Conv2d(1024, 1024, kernel_size=3, padding=0)

        # decoder
        self.upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.d11 = nn.Conv2d(1024, 512, kernel_size=3, padding=0)
        self.d12 = nn.Conv2d(512, 512, kernel_size=3, padding=0)

        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.d21 = nn.Conv2d(512, 256, kernel_size=3, padding=0)
        self.d22 = nn.Conv2d(256, 256, kernel_size=3, padding=0)

        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.d31 = nn.Conv2d(256, 128, kernel_size=3, padding=0)
        self.d32 = nn.Conv2d(128, 128, kernel_size=3, padding=0)

        self.upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.d41 = nn.Conv2d(128, 64, kernel_size=3, padding=0)
        self.d42 = nn.Conv2d(64, 64, kernel_size=3, padding=0)

        # output layer
        self.outconv = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        # encoder
        xe11 = F.relu(self.e11(x))
        xe12 = F.relu(self.e12(xe11))
        xpool1 = self.pool1(xe12)

        xe21 = F.relu(self.e21(xpool1))
        xe22 = F.relu(self.e22(xe21))
        xpool2 = self.pool2(xe22)

        xe31 = F.relu(self.e31(xpool2))
        xe32 = F.relu(self.e32(xe31))
        xpool3 = self.pool3(xe32)

        xe41 = F.relu(self.e41(xpool3))
        xe42 = F.relu(self.e42(xe41))
        xpool4 = self.pool4(xe42)

        xe51 = F.relu(self.e51(xpool4))
        xe52 = F.relu(self.e52(xe51))

        # decoder
        xu1 = self.upconv1(xe52)
        xu11 = torch.cat((xu1, xe42), dim=1)  # skip connection
        xd11 = F.relu(self.d11(xu11))
        xd12 = F.relu(self.d12(xd11))

        xu2 = self.upconv2(xd12)
        xu22 = torch.cat((xu2, xe32), dim=1)
        xd21 = F.relu(self.d21(xu22))
        xd22 = F.relu(self.d22(xd21))

        xu3 = self.upconv3(xd22)
        xu33 = torch.cat((xu3, xe22), dim=1)
        xd31 = F.relu(self.d31(xu33))
        xd32 = F.relu(self.d32(xd31))

        xu4 = self.upconv4(xd32)
        xu44 = torch.cat((xu4, xe12), dim=1)
        xd41 = F.relu(self.d41(xu44))
        xd42 = F.relu(self.d42(xd41))

        # output layer
        out = self.outconv(xd42)

        return out

While training, I use the following loop:

# Training loop
epochs = 10

for epoch in range(epochs):
    model.train()
    train_loss = 0.0

    for i, (features, labels) in enumerate(train_loader):
        features = features.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # forward pass
        outputs = model(features)

        # calculate loss
        loss = loss_fn(outputs, labels)

        loss.backward()

        # update weights
        optimizer.step()

        train_loss += loss.item() * features.size(0)

    train_loss /= len(train_loader.dataset)

    print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}')

I am getting the error below.

RuntimeError                              Traceback (most recent call last)
in <cell line: 4>()
     13 
     14         # forward pass
---> 15         outputs = model(features)
     16 
     17         # calculate loss

6 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    454                             weight, bias, self.stride,
    455                             _pair(0), self.dilation, self.groups)
--> 456         return F.conv2d(input, weight, bias, self.stride,
    457                         self.padding, self.dilation, self.groups)
    458 

RuntimeError: Given groups=1, weight of size [64, 4, 3, 3], expected input[32, 256, 256, 4] to have 4 channels, but got 256 channels instead

Kindly provide me with a solution.

It seems your input is stored in channels-last order, [N, H, W, C], while PyTorch's nn.Conv2d expects channels-first, [N, C, H, W]. That is why the first conv sees 256 "channels": it is reading the height dimension as the channel dimension. Permute the tensor before passing it to the model and it should work.
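For example, here is a minimal sketch of the fix inside your training loop, assuming the DataLoader yields feature batches of shape [32, 256, 256, 4] and label batches of shape [32, 256, 256, 1] (variable names follow your loop above):

    for i, (features, labels) in enumerate(train_loader):
        features = features.to(device)
        labels = labels.to(device)

        # DataLoader yields [N, H, W, C]; nn.Conv2d expects [N, C, H, W]
        features = features.permute(0, 3, 1, 2).contiguous()
        # if labels are [N, H, W, 1], move their channel dim too so they
        # match the model output shape [N, 1, H, W]
        labels = labels.permute(0, 3, 1, 2).contiguous()

        optimizer.zero_grad()
        outputs = model(features)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

Alternatively, you could do the permute once in your Dataset's __getitem__ so every batch already arrives channels-first and the training loop stays unchanged.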
