Convert grayscale images to RGB

Hello, I am trying to classify ImageNet using VGG, and I am using a custom dataset as follows:

# Build the train/test datasets from CSV annotation files; every image is
# resized to the 224x224 input size expected by VGG.
# NOTE(review): CustomDataset and `transforms` are defined/imported elsewhere
# in the file — presumably torchvision.transforms; confirm.
train_dataset=CustomDataset(csv_file='/home/tboonesifuentes/Databases/ImageNet/Train/train.csv',root_dir='/home/tboonesifuentes/Databases/ImageNet/Train/Crops',
                                transform=transforms.Compose([  
                                transforms.ToPILImage(),
                                transforms.Resize([224,224]),
                                transforms.ToTensor()]))
    
    
test_dataset=CustomDataset(csv_file='/home/tboonesifuentes/Databases/ImageNet/Test/test.csv',root_dir='/home/tboonesifuentes/Databases/ImageNet/Test/Crops',
                                transform=transforms.Compose([
                                transforms.ToPILImage(),
                                transforms.Resize([224,224]),
                                #transforms.RandomCrop(24),
                                transforms.ToTensor()]))


# Samples per mini-batch; the train loader below drops the last partial batch.
batch_size=130


class TransformedDataset(torch.utils.data.Dataset):
    """Dataset wrapper that lazily applies ``transform_fn`` to each sample.

    ``transform_fn`` receives the wrapped dataset's sample (e.g. an
    ``(input, target)`` pair) and returns the transformed sample.
    """

    def __init__(self, dataset, transform_fn):
        self.dataset = dataset
        self.transform_fn = transform_fn

    def __len__(self):
        # Same length as the wrapped dataset — no samples added or dropped.
        return len(self.dataset)

    def __getitem__(self, index):
        return self.transform_fn(self.dataset[index])
    

# Remap raw class ids 8..43 to contiguous labels 0..35 (raw -> raw - 8).
labels_mapping = {raw: raw - 8 for raw in range(8, 44)}

def map_targets_fn(dp, target_mapping):
    """Remap a sample's target through ``target_mapping``.

    Parameters
    ----------
    dp : tuple
        ``(input, target)`` pair. The target may be a 0-d tensor
        (the original code assumed this) or a plain hashable value.
    target_mapping : dict
        Maps raw target values to new labels.

    Returns
    -------
    tuple
        ``(input, remapped_target)``.
    """
    x, y = dp
    # .item() unwraps 0-d tensors to Python scalars; plain values pass
    # through unchanged (the original crashed on non-tensor targets).
    key = y.item() if hasattr(y, "item") else y
    return x, target_mapping[key]


# Wrap both datasets so raw labels (8..43) are remapped to contiguous 0..35.
train_dataset = TransformedDataset(train_dataset, partial(map_targets_fn, target_mapping=labels_mapping))
test_dataset = TransformedDataset(test_dataset, partial(map_targets_fn, target_mapping=labels_mapping))

# Debug sweep: flag any sample whose tensor has a single channel (grayscale).
# NOTE(review): the second tuple element is the label, not an image — the
# name `image` is misleading; `idx` is unused. This iterates the whole
# training set once, which is slow for ImageNet-scale data.
for idx, (data,image) in enumerate (train_dataset):
    
    if data.shape[0] == 1:
        print(data.shape)
        print('1D image')
        


# NOTE(review): `num_workers` is assumed to be defined earlier in the file.
train_loader = DataLoader(train_dataset, batch_size,num_workers=num_workers, 
                        shuffle=True, drop_last=True)    
  
test_loader = DataLoader(test_dataset, batch_size,num_workers=num_workers, 
                          shuffle=False, drop_last=False)    

I didn’t know that ImageNet had grayscale images. I actually found some, read them in MATLAB, and confirmed they are grayscale — that’s the reason I’m getting the batch-size mismatch error at position 0. Now I know I have to convert these grayscale images to RGB if I want to train. My question is: where can I catch the grayscale images and convert them to RGB? In MATLAB it would be something like rgbImage = cat(3, A, A, A); where A is the grayscale image, but I don’t know how to do it, or where exactly in my code. Could someone please help?

Assuming the tensors are loaded as [channels, height, width], you could probably use this lambda transformation:

trans = transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0)==1 else x)

x = torch.randn(3, 224, 224)
out = trans(x)
print(out.shape)
> torch.Size([3, 224, 224])

x = torch.randn(1, 224, 224)
out = trans(x)
print(out.shape)
> torch.Size([3, 224, 224])

If you are loading the images via PIL.Image.open inside your custom Dataset, you could also convert them directly to RGB via PIL.Image.open(...).convert('RGB').
However, since you are using ToPILImage as a transformation, I assume you are loading tensors directly.

1 Like

Hello ptrblck, Thanks for your quick response. Actually I discovered I also have images with four channels so I implemented this code in my custom dataset

import os
import pandas as pd
import torch

from torch.utils.data import Dataset

from skimage import io

class CustomDataset(Dataset):
    """Image dataset backed by a CSV of annotation rows.

    Column 0 of the CSV holds the image filename (relative to ``root_dir``)
    and column 2 the integer class label. Every loaded image is normalized
    to a 3-channel H x W x 3 array before the optional ``transform`` runs,
    so grayscale and RGBA files can be batched together safely.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
        # skimage returns a channels-LAST numpy array: H x W (grayscale)
        # or H x W x C (color).
        image = io.imread(img_path)

        if image.ndim == 2:
            # Grayscale: replicate the single plane into H x W x 3
            # (the equivalent of MATLAB's cat(3, A, A, A)).
            # BUG FIX: the old code stacked along dim 0 to get (3, H, W) and
            # then transpose(0, 2)'d it to (W, H, 3), silently swapping
            # height and width (only accidentally correct for square crops).
            # Stacking along dim 2 produces (H, W, 3) directly.
            t = torch.from_numpy(image)
            image = torch.stack([t, t, t], dim=2).numpy()
        elif image.ndim == 3 and image.shape[2] == 4:
            # RGBA: drop the alpha channel.
            # BUG FIX: the old code tested image.shape[0] == 4, but skimage
            # arrays are channels-last, so RGBA images were never detected;
            # the channel count lives in shape[2]. No transpose is needed —
            # slicing the last axis keeps the H x W x 3 layout.
            image = image[:, :, :3]

        y_label = torch.tensor(int(self.annotations.iloc[index, 2]))

        if self.transform:
            image = self.transform(image)
        return (image, y_label)