Why does PyTorch claim my input has 4 channels?

I was trying to train an SRGAN model on my dataset of text images. These are 100 160x30 .png files.

Here is the model: https://github.com/leftthomas/SRGAN/blob/master/model.py
https://github.com/leftthomas/SRGAN/blob/master/train.py
(I didn’t change anything apart from the path to my dataset)

I set crop_size to 24 and the upscale factor to 2, then start training, but the following error occurs:

RuntimeError: Given groups=1, weight of size [64, 3, 9, 9], expected input[64, 4, 12, 12] to have 3 channels, but got 4 channels instead

I am sure that all of my images have 3 channels, which is confirmed by img.shape.
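(For reference, my check was roughly along these lines; the directory path is just a placeholder.)

import glob
import numpy as np
from PIL import Image

for path in glob.glob("my_text_images/*.png"):
    img = np.array(Image.open(path))
    # print any file whose array shape is not H x W x 3
    if img.ndim != 3 or img.shape[2] != 3:
        print(path, img.shape)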

Nevertheless, I am stuck and, honestly, I have no idea where exactly this fourth channel might come from.

Can you please help me figure out what is wrong? I have googled similar discussions, but the most frequent piece of advice there is to make sure the input is not RGBA, which does not apply in my case.

Thanks.

I guess the 4th channel might be the alpha channel (in some images).
If you are using a custom Dataset, you could easily check for it and slice the input image if necessary.
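Something along these lines (a minimal sketch with illustrative names; the SRGAN repo's own dataset classes may look different):

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class TextImageDataset(Dataset):          # illustrative name, not from the repo
    def __init__(self, image_paths):
        self.image_paths = image_paths
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = self.to_tensor(Image.open(self.image_paths[idx]))  # [C, H, W]
        if img.size(0) == 4:              # RGBA -> drop the alpha channel
            img = img[:3]
        return img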

Thanks for answering. I have checked that several times. None of my images have an alpha channel. This is what makes me wonder…

You could try to iterate once over the DataLoader and check the size of the inputs, just to be sure.

This might help you locate the error more precisely (e.g. if all images have 4 channels, or if they have 3 channels up to this point and are modified later on…)
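Something like this (a sketch; train_loader stands for the DataLoader built in train.py, and the (lr, hr) unpacking assumes it yields image pairs):

for batch_idx, (lr_imgs, hr_imgs) in enumerate(train_loader):
    print(batch_idx, lr_imgs.shape, hr_imgs.shape)
    # stop at the first batch whose inputs do not have 3 channels
    if lr_imgs.size(1) != 3:
        print("batch", batch_idx, "has", lr_imgs.size(1), "channels")
        break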

Hi again,

I had the same issue after loading plots saved directly from seaborn. I followed the face landmarks tutorial and your tips on how to load the data into the data loader. I successfully loaded the data and saw the images (so they were there), but then, to my surprise:

0 torch.Size([4, 224, 224])
1 torch.Size([4, 224, 224])

So I followed another topic of yours on how to fix this and used the code below:

from torchvision import transforms
from torch.utils.data import DataLoader

data_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])])

# FaceLandmarksDataset is the class from the face landmarks tutorial
transformed_dataset = FaceLandmarksDataset(csv_file='metadata.csv',
                                           root_dir="/figures/",
                                           transform=data_transforms)

for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]

    print(i, sample['image'].size())

    if i == 3:
        break

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)

And I get the following error:

TypeError: img should be PIL Image. Got <class 'dict'>

It seems that you are following the tutorial that uses skimage.io to read the image, which returns a numpy array (or some other object). You will either need to convert the image to a PIL Image or replace skimage's imread with PIL's Image.open, something like:

from PIL import Image

# open the file directly as a PIL Image so the torchvision transforms can handle it
image = Image.open(img_name)
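Alternatively, if you want to keep skimage.io.imread, you could convert the returned array to a PIL Image before applying the transforms (a sketch; img_name is whatever path your Dataset reads):

from skimage import io
from PIL import Image

arr = io.imread(img_name)       # numpy array, H x W x C
image = Image.fromarray(arr)    # PIL Image that the transforms accept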

It turns out the issue was in the transforms Normalize: I was using 3-channel statistics for 1-channel data. I now manage to easily load, shuffle and reload the data using ImageFolder (thanks to Patrick) and some other techniques I picked up here on the forum.
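In case anyone hits the same thing: with Grayscale(num_output_channels=1) the tensor has a single channel, so Normalize needs a single mean/std (the 0.5 values below are placeholders, not fitted statistics):

from torchvision import transforms

data_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])])  # one value per (single) channel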

Thanks for all the help.