RuntimeError: Given groups=1, weight of size [64, 3, 3, 3], expected input[4, 5000, 5000, 3] to have 3 channels, but got 5000 channels instead

The dataloader looks like this:

from torch.utils.data import Dataset
import glob
import skimage.io  # import the io submodule explicitly; "import skimage" alone may not expose it
import torch

class ImageLoader():
    def __init__(self, Images, Annotations,
                 train_percentage):
        self.Image = glob.glob(Images)
        self.Annotations = glob.glob(Annotations)
        self.train_percentage = train_percentage
        # Split based on the number of globbed files, not the length of the pattern string
        train_len = int(train_percentage * len(self.Image))
        self.train_set = {"Images": self.Image[:train_len],
                          "Annotations": self.Annotations[:train_len]}
        self.test_set = {"Images": self.Image[train_len:],
                         "Annotations": self.Annotations[train_len:]}


class TrainSet(Dataset):
    def __init__(self, train_data, extension="jpeg", transform=None):
        self.extension = extension.lower()
        self.transform = transform
        self.images = train_data["Images"]
        self.target_images = train_data["Annotations"]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        if self.extension == "png":
            # Keep only the first three channels to drop a possible alpha channel
            image = skimage.io.imread(self.images[index])[..., :3]
            label = skimage.io.imread(self.target_images[index])[..., :3]
        elif self.extension == "tif":
            # skimage.external.tifffile was removed from newer scikit-image
            # releases; skimage.io.imread reads tif files as well
            image = skimage.io.imread(self.images[index])
            label = skimage.io.imread(self.target_images[index])
        else:
            image = skimage.io.imread(self.images[index])
            label = skimage.io.imread(self.target_images[index])
        if self.transform:
            image = self.transform(image)
        return {"Image": torch.from_numpy(image), "Label": torch.from_numpy(label)}


class TestSet(Dataset):
    def __init__(self, train_data, extension="jpeg", transform=None):
        self.extension = extension.lower()
        self.transform = transform
        self.images = train_data["Images"]
        self.target_images = train_data["Annotations"]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        if self.extension == "png":
            # Keep only the first three channels to drop a possible alpha channel
            image = skimage.io.imread(self.images[index])[..., :3]
            label = skimage.io.imread(self.target_images[index])[..., :3]
        elif self.extension == "tif":
            # skimage.external.tifffile was removed from newer scikit-image
            # releases; skimage.io.imread reads tif files as well
            image = skimage.io.imread(self.images[index])
            label = skimage.io.imread(self.target_images[index])
        else:
            image = skimage.io.imread(self.images[index])
            label = skimage.io.imread(self.target_images[index])
        if self.transform:
            image = self.transform(image)
        return {"Image": torch.from_numpy(image), "Label": torch.from_numpy(label)}

The model looks like this:

import torch
import torch.nn as nn


def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )


class UNeT(nn.Module):
    def __init__(self, n_class):
        super().__init__()
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear',
                                    align_corners=True)
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)
        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)
        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)
        x = self.dconv_down4(x)
        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)
        out = self.conv_last(x)
        return out

I am feeding 5000x5000x3 images into the network. I have read similar threads about this problem, but unsqueezing the input did not fix it.
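
For context, a quick sanity check (with an arbitrary n_class and a small spatial size) shows the model runs fine with a [batch, channels, height, width] input:

import torch

model = UNeT(n_class=2)               # n_class=2 chosen arbitrarily
dummy = torch.randn(1, 3, 256, 256)   # [batch, channels, height, width]
print(model(dummy).shape)             # torch.Size([1, 2, 256, 256])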

It looks like you are passing the input in "channels last" format.
Currently you have to pass image tensors as [batch_size, channels, height, width].
You could use x.permute(0, 3, 1, 2) to permute the dimensions.
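
For example, assuming the batch from your DataLoader comes out as [batch, height, width, channels] (as the error message suggests), a minimal sketch of the fix in the training loop would be:

for batch in train_loader:
    # [B, H, W, C] -> [B, C, H, W]; .float() because imread usually returns uint8
    inputs = batch["Image"].permute(0, 3, 1, 2).float()
    outputs = model(inputs)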

I assume you are trying to tell me to permute the dimensions of the images and the label images, right? They need to be given to the network as [b, c, h, w] and I'm giving them in [b, h, w, c] format. I tried inputs.permute(0, 3, 1, 2) and labels.permute(0, 3, 1, 2), but I get RuntimeError: number of dims don't match in permute. Also, I think the tensor permute documentation needs to be more informative.

Maybe the batch dimension is missing for the tensors you would like to permute (e.g. if you are applying it inside __getitem__), so could you check the shape of the tensors?
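
In other words, inside __getitem__ the image tensor is still 3-dimensional ([height, width, channels]), so a 4-element permutation has nothing to match against. A small sketch of the two options:

import numpy as np
import torch

image = np.zeros((8, 8, 3), dtype=np.uint8)  # stand-in for skimage.io.imread(...)
t = torch.from_numpy(image)
print(t.shape)          # torch.Size([8, 8, 3]) -- no batch dimension yet

# Option 1: permute the 3D tensor inside __getitem__
t = t.permute(2, 0, 1)  # [H, W, C] -> [C, H, W]; the DataLoader adds the batch dim

# Option 2: leave __getitem__ as is and permute the batched 4D tensor
# in the training loop with inputs.permute(0, 3, 1, 2)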

Thanks for the feedback regarding the docs. Do you have any suggestions on what to add to make them more informative?

Thank you, I was able to solve it. As for the documentation, although it is pretty self-explanatory, it could say something like "swaps the axes into the order given by *dims".