RuntimeError: Given groups=1, weight[64, 3, 3, 3], so expected input[16, 64, 256, 256] to have 3 channels, but got 64 channels instead

Why am I getting this error?

RuntimeError: Given groups=1, weight[64, 3, 3, 3], so expected input[16, 64, 256, 256] to have 3 channels, but got 64 channels instead

I wrote an implementation of U-net.

class double_conv(nn.Module):
  def __init__(self, in_ch, out_ch):
    super(double_conv, self).__init__()
    self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
    self.conv2 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
    
  def forward(self, x):
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    return x
  
  
class input_conv(nn.Module):
  def __init__(self, in_ch, out_ch):
    super(input_conv, self).__init__()
    self.inp_conv = double_conv(in_ch, out_ch)
    
  def forward(self, x):
    x = self.inp_conv(x)
    return x
  

class up(nn.Module):
  def __init__(self, in_ch, out_ch):
    super(up, self).__init__()
    self.up_conv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
    self.conv = double_conv(in_ch, out_ch)
    
  def forward(self, x1, x2):
    x1 = self.up_conv(x1)
    x = torch.cat([x2, x1], dim=1)
    x = self.conv(x)
    return x
    
    
class down(nn.Module):
  def __init__(self, in_ch, out_ch):
    super(down, self).__init__()
    self.pool = nn.MaxPool2d(2)
    self.conv = double_conv(in_ch, out_ch)
    
  def forward(self, x):
    x = self.pool(x)
    x = self.conv(x)
    return x
  
class last_conv(nn.Module):
  def __init__(self, in_ch, out_ch):
    super(last_conv, self).__init__()
    self.conv1 = nn.Conv2d(in_ch, out_ch, 1)
    
  def forward(self, x):
    x = self.conv1(x)
    return x


class Unet(nn.Module):
  def __init__(self, channels, classes):
    super(Unet, self).__init__()
    self.inp = input_conv(channels, 64)
    self.down1 = down(64, 128)
    self.down2 = down(128, 256)
    self.down3 = down(256, 512)
    self.down4 = down(512, 1024)
    self.up1 = up(1024, 512)
    self.up2 = up(512, 256)
    self.up3 = up(256, 128)
    self.up4 = up(128, 64)
    self.out = last_conv(64, classes)
    
  def forward(self, x):
    x1 = self.inp(x)
    x2 = self.down1(x1)
    x3 = self.down2(x2)
    x4 = self.down3(x3)
    x5 = self.down4(x4)
    x = self.up1(x5, x4)
    x = self.up2(x, x3)
    x = self.up3(x, x2)
    x = self.up1(x, x1)
    x = self.out(x)
    return x


model = Unet(3, 1)

This is the training loop

for epoch in range(5):
    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs = Variable(inputs).cuda()
        labels = Variable(labels).cuda()
        
        # forward + backward + optimize
        
        # zeroes the gradient buffers of all parameters
        optimizer.zero_grad()
        #forward pass
        outputs = model(inputs)
        # calculate the loss
        loss = loss_function(outputs, labels)
        # backpropagation
        loss.backward()
        # Does the update after calculating the gradients
        optimizer.step()
        
        if (i+1) % 5 == 0: # print every 5 mini-batches
            print('[%d, %5d] loss: %.4f' % (epoch, i+1, loss.data[0]))

It means the input should have 3 channels, but you are passing a 64-channel input. Inputs are organized in [N, C, H, W] format; your input (i.e. the data layer) should have 3 channels. You should check your code.
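A minimal illustration of the mismatch, using the shapes from the error message (this is just a sketch to show where the check happens):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)  # weight shape: [64, 3, 3, 3]
x_bad = torch.randn(16, 64, 256, 256)              # 64 channels -> would raise the error above
x_ok = torch.randn(16, 3, 256, 256)                # 3 channels -> works
out = conv(x_ok)                                   # output shape: [16, 64, 256, 256]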


My input does have 3 channels. The Unet is constructed with (3, 1), which corresponds to 3 channels and 1 class.

I believe in @junyuseu's answer, "input" means the input to a conv layer, not the input to the whole network. In fact, unless your in_ch always equals out_ch, the double_conv module will always throw such an error.

The error is here: in the double_conv class, the first conv layer's input dim is in_ch (3) and its output dim is out_ch (64). The second conv layer is defined the same way, but the input it actually receives has 64 channels (the output of conv1), not in_ch (3).
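In other words, the second conv layer needs to take out_ch as its input channels. A minimal sketch of the corrected double_conv (everything else unchanged):

class double_conv(nn.Module):
  def __init__(self, in_ch, out_ch):
    super(double_conv, self).__init__()
    self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
    self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)  # in_ch -> out_ch was the bug

  def forward(self, x):
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    return x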


Yup, made a silly mistake. Thanks for the help

I got the same kind of runtime error: "Given groups=1, weight of size [64, 3, 7, 7], expected input[1, 4, 224, 224] to have 3 channels, but got 4 channels instead".

I've been trying to correct it for a while now but can't seem to see where the mistake is.
This is my model arch.

Any help would be appreciated.

Your input contains 4 channels, while the first conv layer expects an input with 3 channels.
If you are dealing with RGB images, the 4th channel might be the alpha channel, which could just be removed.
If you are using a custom Dataset, you could probably just use:

def __getitem__(self, index):
    # Option 1: let PIL drop the alpha channel
    img = Image.open(self.paths[index]).convert('RGB')
    ...
    # Option 2: remove the alpha channel from the tensor instead
    img = Image.open(self.paths[index])
    x = TF.to_tensor(img)  # TF = torchvision.transforms.functional
    x = x[:3]              # keep only the first 3 channels
    ...

Thanks for your response! I had a different but related question.

If you are dealing with a 1-channel grayscale image but want to utilize a pretrained network (I am working off of a resnet18-based repo), would you suggest duplicating the channel 3 times so it fits the model, or is there a better approach?

I think this would be the easiest approach.
Alternatively, you could also try to reduce the filter channels (mean, sum, ?) and see if that gives better performance (I haven't compared these approaches yet).
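A rough sketch of both options, assuming a torchvision resnet18 (attribute names like conv1 are specific to that architecture):

import torch
import torchvision

x = torch.randn(4, 1, 224, 224)     # example grayscale batch: [N, 1, H, W]

# Option 1: repeat the single channel 3 times to match the pretrained stem
x3 = x.repeat(1, 3, 1, 1)           # [N, 3, H, W]

# Option 2: replace the first conv and reduce the pretrained filters over the channel dim
model = torchvision.models.resnet18(pretrained=True)
w = model.conv1.weight              # [64, 3, 7, 7]
new_conv = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    new_conv.weight.copy_(w.sum(dim=1, keepdim=True))  # or w.mean(dim=1, keepdim=True)
model.conv1 = new_conv
out = model(x)                      # now accepts 1-channel input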


I am getting the same error. Any idea where that might be in my code? I can't seem to figure out where the mistake is…

    class CNN(nn.Module):
        def __init__(self):
            super(CNN, self).__init__()

            self.conv1 = nn.Conv2d(4, 7, (3, 3))
            # self.convnorm1 = nn.BatchNorm2d(7)
            # self.pool1 = nn.MaxPool2d((2, 2))

            self.conv2 = nn.Conv2d(7, 14, (3, 3))
            self.convnorm2 = nn.BatchNorm2d(14)
            # self.pool2 = nn.MaxPool2d((2, 2))
            #
            self.conv3 = nn.Conv2d(14, 28, (3, 3))
            # self.convnorm3 = nn.BatchNorm2d(28)
            # self.pool3 = nn.MaxPool2d((2, 2))

            self.conv4 = nn.Conv2d(28, 56, (3, 3))
            self.convnorm4 = nn.BatchNorm2d(56)
            # self.pool4 = nn.MaxPool2d((2, 2))

            self.conv5= nn.Conv2d(56, 112, (3, 3))
            self.convnorm5 = nn.BatchNorm2d(112)
            # self.pool5 = nn.MaxPool2d((2, 2))

            self.conv6 = nn.Conv2d(112, 224, (2, 2))
            self.convnorm6 = nn.BatchNorm2d(224)
            # self.pool6 = nn.MaxPool2d((2, 2))
            #
            # self.conv7 = nn.Conv2d(224, 448, (2, 2))
            # self.convnorm7 = nn.BatchNorm2d(448)
            # self.pool7 = nn.MaxPool2d((2, 2))

            self.linear1 = nn.Linear(47096, 200)
            self.linear1_bn = nn.BatchNorm1d(200)
            self.drop = nn.Dropout(DROPOUT)
            self.linear2 = nn.Linear(200, 17)
            self.act = torch.relu

        def forward(self, x):
            # x = self.act(self.conv1(x))
            # x = self.pool1(self.convnorm1(self.act(self.conv1(x))))
            #x = self.pool1(self.act(self.conv1(x)))
            x = self.act(self.conv1(x))
            #x = self.pool2(self.convnorm2(self.act(self.conv2(x))))
            x = self.convnorm2(self.act(self.conv2(x)))
            x = self.act(self.conv3(x))
            # x = self.pool3(self.convnorm3(self.act(self.conv3(x))))
            # x = self.pool4(self.convnorm4(self.act(self.conv4(x))))
            x =self.convnorm4(self.act(self.conv4(x)))
            # x = self.pool5(self.convnorm5(self.act(self.conv5(x))))
            # x = self.pool6(self.convnorm6(self.act(self.conv6(x))))
            # x = self.pool7(self.convnorm7(self.act(self.conv7(x))))
            x = self.drop((self.linear1_bn(self.act(self.linear1(x.view(len(x), -1))))))
            return self.linear2(x)

What is your input shape?
The code works fine using this snippet:

model = CNN()
x = torch.randn(2, 4, 37, 37)
out = model(x)

@ptrblck sir, please explain what input shape means, I forgot. I am also having a similar error with my code; please help me figure it out.
Here is the snippet:

# -*- coding: utf-8 -*-

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from PIL import Image
import torch.optim as optim

from torch.autograd import Variable
import numpy as np
import json
#loading the cascades

face_cascade = cv2.CascadeClassifier("haarcascade_frontalface_default.xml")
eye_cascade = cv2.CascadeClassifier("haarcascade_eye.xml")
#definnig a function that will do the detections
def detect(gray, frame):
    faces = face_cascade.detectMultiScale(gray, 1.3, 5)
    for (x, y, w, h) in faces:
        cv2.rectangle(frame, (x, y), (x+w, y+h), (255, 0, 0), 2)
        roi_gray = gray[y:y+h, x:x+w]
        roi_color = frame[y:y+h, x:x+w]
        eyes = eye_cascade.detectMultiScale(roi_gray, 1.1, 3)
        for (ex, ey, ew, eh) in eyes:
            cv2.rectangle(roi_color, (ex, ey), (ex+ew, ey+eh), (0, 255, 0), 2)
    return frame
#applying the detect function

class CNN(nn.Module):
    def __init__(self, num_inputs):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(num_inputs, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.fc1 = nn.Linear(in_features=32*3*3, out_features=40)
        self.fc2 = nn.Linear(in_features=40, out_features=2)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3, 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 3, 2))
        x = F.relu(F.max_pool2d(self.conv3(x), 3, 2))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
#train your neural network
cnn = CNN(3)  # specify the number of input channels

# Setting some hyperparameters

batchSize = 4 # We set the size of the batch.
imageSize = 64 # We set the size of the generated images (64x64).
path=r"C:\Users\SHRIKANT\Music\Downloads\Simple_Linear_Regression\computer-vision\Module 3 - GANs"

# Creating the transformations

#transform = transforms.Compose([transforms.Scale(imageSize), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]) # We create a list of transformations (scaling, tensor conversion, normalization) to apply to the input images.

#dataloader = torch.utils.data.DataLoader(dataset, batch_size = batchSize, shuffle = True, num_workers = 2) # We use dataLoader to get the images of the training set batch by batch.
os.chdir(path)
loss = nn.MSELoss()

for epoch in range(25):
#train to find adarsh pandey

  for l in range(1,66):
      fak=Image.open(str(l)+".png").convert("RGB")
      img = Image.open(self.paths[index])
      x = TF.to_tensor(img)
      x = x[:3]
       #with open (str(l)+".png",'rb') as fak:  
      input=Variable(torch.from_numpy(np.array(fak,dtype='uint8'))).unsqueeze(0)
      #np.array( img, dtype='uint8' )
      output=cnn(input)
      optimizer = optim.Adam(cnn.parameters(), lr = 0.001)
      #what if optimizer.zero_grad()was here?? any effect .......plz check it 
      loss_error=loss(output,input)
      optimizer.zero_grad()
      
      loss_error.backward()
      optimizer.step()
  #training for identyfying of shubhi pandey
  
  
  #training ends here 
  probs=F.softmax(output)
  label=probs.multinomial()[0]
  print(label)

video_capture=cv2.VideoCapture(0)
#f_name=100
while True:

    _, frame = video_capture.read()
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    canvas = detect(gray, frame)

    #os.chdir(path)

    #cv2.imwrite(str(f_name)+".png", frame)
    cv2.imshow('Video', canvas)
    #f_name=f_name+1
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

video_capture.release()
cv2.destroyAllWindows()

Input shape refers to the shape of your input tensor(s), which you are passing to your model.
E.g. nn.Conv2d modules expect a 4-dimensional tensor with the shape [batch_size, channels, height, width].
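For example, you can print the shape of the tensor you are passing to the model (the values here are just an illustration):

import torch

x = torch.randn(1, 3, 224, 224)  # [batch_size, channels, height, width]
print(x.shape)                   # torch.Size([1, 3, 224, 224])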

Could you remove unrelated code snippets (e.g. the opencv cascade etc.) and post a (minimal) code snippet to reproduce your error using random input data, please?

PS: You can add code snippets by wrapping them in three backticks ``` :wink:


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, bias=False)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

model = LeNet().to(device=device)

if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}".format(args.model, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

print('Pre-processing Successful!')

def test(model):
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.Resize((32, 32)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

I am getting this error: RuntimeError: Given groups=1, weight of size 3 6 5 5, expected input[256, 3, 14, 14] to have 6 channels, but got 3 channels instead
If I change bias=True, then I get this error:
RuntimeError: Given weight of size 3 1 5 5, expected bias to be 1-dimensional with 3 elements, but got bias of size [6] instead

Answered here.

Hello, I have this dataset maker and I get a similar issue. I tried what you suggested but I get this error:

 return data[:3], target[:3], index
IndexError: invalid index to scalar variable.

Do you know why it is a scalar instead of a tensor?

class MyDataset(Dataset):
    def __init__(self,remove_list):
        self.cifar10 = datasets.CIFAR10(root='./data',

                                        download=False,
                                        train=True,
                                        transform=transform)
        self.data = self.cifar10.data
        # self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        # self.data = self.data.transpose((0, 2, 3, 1))
        self.targets = self.cifar10.targets
        self.final_data, self.final_targets = self.__remove__(remove_list)
      
    def __getitem__(self, index):
        data, target = self.final_data[index], self.final_targets[index]
        
        return data[:3], target[:3], index

    def __len__(self):
        return len(self.final_data)

    def __remove__(self, remove_list):
        data = np.delete(self.data, remove_list, axis=0)        
        targets = np.delete(self.targets, remove_list, axis=0)
        
        return data, targets

target could be a scalar, as you are calling self.final_targets with a single index.
Also data should be a single sample and I’m not sure what the [:3] indexing is doing.
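For reference, a minimal example of why the slicing fails, assuming the targets end up as a numpy array of integer labels (as they do after np.delete):

import numpy as np

targets = np.delete(np.array([3, 8, 0, 5]), [2])  # numpy array of int labels
target = targets[1]                               # a single index returns a numpy scalar
# target[:3]  # raises "IndexError: invalid index to scalar variable."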

Could you explain your use case a bit and what you are trying to achieve?


I’m getting a very similar error.

“Given groups=1, weight of size [10, 3, 7, 7], expected input[1, 64, 64, 3] to have 3 channels, but got 64 channels instead”

I understand the input should be [batch, channels, height, width] but my environment returns me arrays of size (64, 64, 3). How do I make it (3, 64, 64), input it into the nn.Conv2d and get around this error?

Please help. @ptrblck

If you are loading the images with e.g. OpenCV, you would get an array in the shape [H, W, C] and would need to permute it to [C, H, W].
If you want to perform this permutation on the numpy array, you could use np.transpose, while tensor.permute should be used in PyTorch.
