Caltech256 containing grayscale images

I was trying to use the Caltech256 dataset to fine-tune a ResNet-50 model. The images in the dataset are supposed to be 3-channel RGB images. However, when applying the transformations, I got an error because some pictures contain only one channel.

from torchvision import models, datasets

resnet_weights = models.ResNet50_Weights.DEFAULT
caltech256 = datasets.Caltech256(
    root="data",
    download=True,
    transform=resnet_weights.transforms()
)

for i in range(len(caltech256)):
    print(i)
    print(caltech256[i][0].shape)

Running this code raises the following error:

RuntimeError                              Traceback (most recent call last)

<ipython-input-17-ed84e6e9f7b9> in <module>
      1 for i in range(len(caltech256)):
      2   print(i)
----> 3   print(caltech256[i][0].shape)


/usr/local/lib/python3.9/dist-packages/torchvision/datasets/caltech.py in __getitem__(self, index)
    211 
    212         if self.transform is not None:
--> 213             img = self.transform(img)
    214 
    215         if self.target_transform is not None:

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.9/dist-packages/torchvision/transforms/_presets.py in forward(self, img)
     59             img = F.pil_to_tensor(img)
     60         img = F.convert_image_dtype(img, torch.float)
---> 61         img = F.normalize(img, mean=self.mean, std=self.std)
     62         return img
     63 

/usr/local/lib/python3.9/dist-packages/torchvision/transforms/functional.py in normalize(tensor, mean, std, inplace)
    358         raise TypeError(f"img should be Tensor Image. Got {type(tensor)}")
    359 
--> 360     return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
    361 
    362 

/usr/local/lib/python3.9/dist-packages/torchvision/transforms/functional_tensor.py in normalize(tensor, mean, std, inplace)
    938     if std.ndim == 1:
    939         std = std.view(-1, 1, 1)
--> 940     return tensor.sub_(mean).div_(std)
    941 
    942 

RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]

From the print statements it can be seen that it breaks at i=15, where Normalize tries to broadcast its 3-channel mean and std against a 1-channel tensor.

Checking this without the normalization confirms that the sample at i=15 contains only one channel:

from torchvision import models, datasets, transforms

resnet_weights = models.ResNet50_Weights.DEFAULT
caltech256 = datasets.Caltech256(
    root="data",
    download=True,
    transform=transforms.ToTensor()
)
print(caltech256[15][0].shape)
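
A quick, if slow, way to list every affected sample is to scan the whole dataset for 1-channel tensors (reusing the caltech256 object from the snippet above):

gray_idxs = [
    i for i in range(len(caltech256))
    if caltech256[i][0].shape[0] == 1  # channel dim of the [C, H, W] tensor
]
print(gray_idxs)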

Is there something I am missing? How should one deal with this problem?

Based on this overview it seems both image types are expected:

Caltech256 dataset. (RGB and grayscale images of various sizes in 256 categories for a total of 30608 images).

However, I don’t see an easy way to convert the images to RGB besides deriving a custom Dataset and calling img = img.convert("RGB") after this Image.open call.
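
A minimal sketch of that idea, written as a wrapper dataset (the RGBWrapper name is made up for this example): create the base Caltech256 without a transform so it yields PIL images, force RGB in the wrapper, and only then apply the weights' preprocessing.

from torch.utils.data import Dataset
from torchvision import datasets, models

class RGBWrapper(Dataset):
    # Wraps a dataset yielding (PIL image, target) pairs and forces
    # 3-channel RGB before applying the given transform
    def __init__(self, base, transform=None):
        self.base = base
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        img, target = self.base[index]
        # grayscale -> 3 identical channels; RGB content is unchanged
        img = img.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, target

resnet_weights = models.ResNet50_Weights.DEFAULT
base = datasets.Caltech256(root="data", download=True)  # no transform here
caltech256 = RGBWrapper(base, transform=resnet_weights.transforms())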

Thanks for your reply. I did not know that mixing RGB and grayscale images was a viable option in training models.

I will therefore look into creating a custom dataset.

Another possible solution is to replicate the gray channel three times to create a new tensor of shape [3, h, w]. In my case I have to compute the mean and std in order to normalize a random split, which I did as follows:

import numpy as np
import torch
from torchvision import datasets
from torchvision.transforms import v2

# cls_datasets, logger and download_caltech256 come from the rest of my code
def init_caltech256():
    if len(cls_datasets['caltech256']['train_idxs']) == 0 and len(cls_datasets['caltech256']['test_idxs']) == 0:

        download_caltech256()

        logger.info(' => Computing mean and std for a random train test split...')

        # Random split: the first 10000 indices for training, the rest for testing
        rand_perm = torch.randperm(cls_datasets['caltech256']['n_images'])
        cls_datasets['caltech256']['train_idxs'] = rand_perm[:10000].tolist()
        cls_datasets['caltech256']['test_idxs'] = rand_perm[10000:].tolist()

        train_data = datasets.Caltech256('./datasets/', download=False, transform=v2.Compose([v2.Resize((224, 224))]))

        # Collect all training images as HWC uint8 arrays, replicating the
        # single channel of the grayscale images three times
        images = []
        for i in cls_datasets['caltech256']['train_idxs']:
            img = np.asarray(train_data[i][0])
            if img.ndim == 3:
                images.append(img)
            else:
                images.append(np.repeat(img[:, :, np.newaxis], 3, axis=2))

        images = np.concatenate(images)

        # Per-channel statistics, scaled to [0, 1] to match float image tensors
        mean = np.mean(images, axis=(0, 1)) / 255
        std = np.std(images, axis=(0, 1)) / 255

        cls_datasets['caltech256']['transforms']['train'].append(v2.Normalize(mean=mean, std=std))
        cls_datasets['caltech256']['transforms']['train'] = v2.Compose(cls_datasets['caltech256']['transforms']['train'])

        cls_datasets['caltech256']['transforms']['test'].append(v2.Normalize(mean=mean, std=std))
        cls_datasets['caltech256']['transforms']['test'] = v2.Compose(cls_datasets['caltech256']['transforms']['test'])

        logger.info(' DONE\n')
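
One caveat: the train/test pipelines built above still need a grayscale-to-RGB step before the v2.Normalize, otherwise the 1-channel samples hit the same broadcast error at training time. A minimal sketch of that step, assuming PIL images enter the pipeline and reusing the computed mean and std (the variable names here are illustrative):

import torch
from torchvision.transforms import v2

def to_rgb(img):
    # PIL's convert("RGB") replicates a single gray channel into three
    return img.convert("RGB")

train_transform = v2.Compose([
    to_rgb,                                 # plain callables are allowed in Compose
    v2.Resize((224, 224)),
    v2.ToImage(),                           # PIL image -> tensor image
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float [0, 1]
    v2.Normalize(mean=mean.tolist(), std=std.tolist()),
])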