Loaded grayscale image is converted into 3 channels

I am using a dataset of 64x64 grayscale images to train a network for two-class image classification. The images are stored in a train folder, with one subfolder per class named after the class label, but whenever I load the images they are converted into 3-channel images. Here is the code:

import torch
import torchvision
import torchvision.transforms as transforms

batch_size = 100  # matches the batch dimension in the output below

data_dir_train = 'C:/Users/rukhm/nodule_classification/data/train/'
data_dir_test = 'C:/Users/rukhm/nodule_classification/data/test/'
train_dataset = torchvision.datasets.ImageFolder(data_dir_train, transform=transforms.ToTensor())
test_dataset = torchvision.datasets.ImageFolder(data_dir_test, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(dataset=train_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)
testloader = torch.utils.data.DataLoader(dataset=test_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)
dataiter = iter(trainloader)
images, labels = next(dataiter)

print(images.shape)  # shape of the whole batch
print(images[1].shape)  # shape of one image
print(labels[1].item())  # class index of that image

Output:
torch.Size([100, 3, 64, 64])
torch.Size([3, 64, 64])
0

Could you try to load a single image via img = PIL.Image.open(path) and check its shape?
If that returns an RGB image, you could add transforms.Grayscale to the transformations.

from PIL import Image
import torchvision.transforms as transforms

img = Image.open('C:/Users/rukhm/nodule_classification/data/train/1/LIDC-IDRI-0003_74_1.png')
im = transforms.ToTensor()(img)
im.shape
Output:
torch.Size([1, 64, 64])
A single image loads as a one-channel grayscale image.
Is it possible to check the shape of the images in train_dataset/test_dataset?
I need to find the line where they are converted into 3-channel images.
Thanks for the help!!

I am having the same problem: no matter what I do, I cannot load my greyscale images with 1 channel.

I am currently facing the same issue. My single-channel .bmp images are somehow converted into 3 channels. As far as I can tell, the issue is caused by the ImageFolder class provided by PyTorch. Is there any way to stop ImageFolder from converting the images into 3 channels? The only option I currently see is to undo the conversion with the Grayscale transform as a workaround.

ImageFolder uses default_loader, which calls accimage_loader if the accimage image backend is selected and pil_loader otherwise. pil_loader (in torchvision/datasets/folder.py) converts every loaded image to RGB via img.convert('RGB').
You can create a custom loader by e.g. just reusing the pil_loader and removing the convert call, and pass it into the loader argument in ImageFolder.


Doing what ptrblck suggested is definitely the “proper” way of doing this.

But if you want a quick fix that does not involve a custom loader, you can just add torchvision.transforms.Grayscale(num_output_channels=1) as an extra transformation.

In your code above, it would look like this:

from torchvision import datasets, transforms

train_dataset = datasets.ImageFolder(data_dir_train,
                                     transform=transforms.Compose([
                                         transforms.Grayscale(num_output_channels=1),
                                         transforms.ToTensor()]))

All PIL's Image.convert('RGB') does to a single-channel image is copy it into three identical channels (r == g == b), and transforms.Grayscale(num_output_channels=1) computes 0.2989 * r + 0.587 * g + 0.114 * b, whose weights sum to about 1. So this round trip should barely change your data, apart from a possible rounding difference.


Thanks for your help!
I want to try to implement a custom PIL loader, but I am currently struggling a bit with it.

My plan was to add the custom PIL loader to folder.py.
So I copied the regular pil_loader function from where ptrblck showed me it is defined and pasted it right above the original.
I renamed the copy to custom_pil_loader and removed the convert call.

def custom_pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

After this, I thought it would be enough to pass the new loader to ImageFolder.
I tried the following, but it didn't work.
Could you tell me what I am doing wrong?

 train_ds = ImageFolder(os.path.join(opt.dataroot, 'train'), transform, loader='custom_pil_loader')
 valid_ds = ImageFolder(os.path.join(opt.dataroot, 'test'), transform, loader='custom_pil_loader')

This is how I called ImageFolder before, which worked, except that single-channel images were converted to RGB:

train_ds = ImageFolder(os.path.join(opt.dataroot, 'train'), transform)
valid_ds = ImageFolder(os.path.join(opt.dataroot, 'test'), transform)

Then I tried to access the first batch of images from the DataLoader:

train_dl = DataLoader(dataset=train_ds, batch_size=opt.batchsize, shuffle=True, drop_last=True)
valid_dl = DataLoader(dataset=valid_ds, batch_size=opt.batchsize, shuffle=False, drop_last=False)

I ran the following command:

iter(train_dl).next()[0].shape

With the default_loader, I get the following output:

Out[1]: torch.Size([32, 3, 32, 32])

But when I run this command using my custom_pil_loader, I get the following error message:

Traceback (most recent call last):
  File "/fibus/fs1/16/cql7772/.local/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3326, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-1-405944bb4251>", line 1, in <module>
    iter(train_dl).next()[0].shape
  File "/fibus/fs1/16/cql7772/.local/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 346, in __next__
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/fibus/fs1/16/cql7772/.local/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/fibus/fs1/16/cql7772/.local/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/fibus/fs1/16/cql7772/.local/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 138, in __getitem__
    sample = self.loader(path)
TypeError: 'str' object is not callable

Your approach works; I tried it.
You just need to pass the function custom_pil_loader itself to ImageFolder, not its name as a string: ImageFolder calls self.loader(path), and a string cannot be called.

The following worked for me:

import torch
import torchvision.transforms as transforms
import torchvision.datasets as dset
from PIL import Image

DATA_ROOT = '../data/dataset/'

def custom_pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        img.load()  # read the pixel data while the file is still open
        return img

transform = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(256),
                                transforms.ToTensor(),
                                ])

default_ds = dset.ImageFolder(DATA_ROOT, transform)
custom_ds = dset.ImageFolder(DATA_ROOT, transform, loader=custom_pil_loader)

default_dl = torch.utils.data.DataLoader(dataset=default_ds, batch_size=32)
custom_dl = torch.utils.data.DataLoader(dataset=custom_ds, batch_size=32)

print(next(iter(default_dl))[0].shape)
print(next(iter(custom_dl))[0].shape)

Output:
torch.Size([32, 3, 256, 256])
torch.Size([32, 1, 256, 256])

I also had to add img.load() at the end of the custom loader; otherwise I got ValueError: seek of closed file.
I'm not sure why, but I found this fix on Stack Overflow.
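A likely explanation: PIL's Image.open is lazy. It parses only the file header, and the pixel data is read the first time it is needed. In the original pil_loader, img.convert('RGB') forces that read inside the with-block; once the convert call is removed, the pixels are only read later (for example by a transform), after the with-block has already closed the file. The sketch below, assuming only Pillow and a hypothetical temporary PNG, reproduces the error with a loader that omits img.load() and shows that the version with img.load() works.

```python
import os
import tempfile

from PIL import Image


def lazy_loader(path):
    # Loader without img.load(): Image.open only parses the header, and the
    # file is closed as soon as the with-block returns.
    with open(path, "rb") as f:
        return Image.open(f)


def eager_loader(path):
    # Same loader with img.load(): pixel data is read while f is still open.
    with open(path, "rb") as f:
        img = Image.open(f)
        img.load()
        return img


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "gray.png")  # hypothetical sample file
    Image.new("L", (32, 32), color=200).save(path)

    lazy = lazy_loader(path)
    try:
        lazy.load()  # first access to the pixels reads from the closed file
        lazy_error = None
    except ValueError as exc:
        lazy_error = exc

    eager = eager_loader(path)

print(repr(lazy_error))
print(eager.mode, eager.getpixel((0, 0)))  # L 200
```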


Many thanks, RaLo4!
I did it your way and it works perfectly!
I also got the same ValueError: seek of closed file when I commented out img.load(), so it does seem to be necessary.
