RuntimeError: output with shape [1, 32, 32] doesn't match the broadcast shape [3, 32, 32] during transformations

I was trying to replace the dataset in this repo: GitHub - MikhailKravets/vision_transformer (a step-by-step tutorial on building a vision transformer from scratch), with this dataset: wayang_bagong_cepot_gareng_petruk_semar | Kaggle.

You can find my clone and the implementation at https://github.com/jameswong3388/vision-transformer/blob/master/src/dataset.py

Here is what I did to replace CIFAR10DataModule:

class WayangKulitDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32, patch_size: int = 4, val_batch_size: int = 16,
                 im_size: int = 32, rotation_degrees: (int, int) = (-30, 30)):
        super().__init__()
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size

        # ViTFeatureExtractor
        # {
        #     "do_normalize": true,
        #     "do_resize": true,
        #     "feature_extractor_type": "ViTFeatureExtractor",
        #     "image_mean": [
        #         0.485,
        #         0.456,
        #         0.406
        #     ],
        #     "image_std": [
        #         0.229,
        #         0.224,
        #         0.225
        #     ],
        #     "resample": 3,
        #     "size": 224
        # }

        self.train_transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(size=(im_size, im_size)),

                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(size=(im_size, im_size)),
                transforms.RandomRotation(degrees=rotation_degrees),

                transforms.ToTensor(),
                transforms.Normalize((0.63528919, 0.57810118, 0.51988552), (0.33020571, 0.34510824, 0.36673283)),
                PatchifyTransform(patch_size),
            ]
        )
        self.val_transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(size=(im_size, im_size)),

                # transforms.Lambda(lambda x: x.squeeze(0)),

                transforms.ToTensor(),
                transforms.Normalize((0.63528919, 0.57810118, 0.51988552), (0.33020571, 0.34510824, 0.36673283)),
                PatchifyTransform(patch_size),
            ]
        )

        self.ds_train = None
        self.ds_val = None

    def setup(self, stage: str):
        self.ds_train = WayangKulit('datas/train', transform=self.train_transform)
        self.ds_val = WayangKulit('datas/val', transform=self.val_transform)

    def train_dataloader(self):
        return DataLoader(self.ds_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=self.batch_size)

    @property
    def classes(self):
        """Returns the amount of WayangKulitDataModule classes"""
        return 5  # Wayang kulit dataset has 5 possible classes


class WayangKulit(Dataset):
    def __init__(self, dataset_path, transform):
        self.dataset_path = dataset_path
        self.transform = transform
        self.dataset = load_dataset(path=self.dataset_path, name='wayang_kulit')

        self.train_image_paths = []
        self.classes = []

        for data_path in glob.glob(self.dataset_path + '/*'):
            self.classes.append(data_path.split('/')[-1])
            self.train_image_paths.append(glob.glob(data_path + '/*'))

        self.train_image_paths = list(flatten(self.train_image_paths))
        random.shuffle(self.train_image_paths)

    def __len__(self):
        return len(self.dataset['train'])

    def __getitem__(self, idx):
        image_filepath = self.train_image_paths[idx]

        image = cv2.imread(image_filepath) # (225, 175, 3)

        if self.transform:
            image = self.transform(image)

        idx_to_class = {i: j for i, j in enumerate(self.classes)}
        class_to_idx = {value: key for key, value in idx_to_class.items()}

        label = image_filepath.split('/')[-2]
        label = class_to_idx[label]

        if self.transform:
            image = self.transform(image)

        return image, label

Since the images in the dataset are 175 x 225, I had to resize them to 32 x 32 to make them work with the original repo, but I got the following error:

Traceback (most recent call last):
  File "/Users/jameswong/PycharmProjects/AIM/vision_transformer-main/train.py", line 72, in <module>
    trainer.fit(model, data)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in _run
    results = self._run_stage()
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1191, in _run_stage
    self._run_train()
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1204, in _run_train
    self._run_sanity_check()
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1276, in _run_sanity_check
    val_loop.run()
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 121, in advance
    batch = next(data_fetcher)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/utilities/fetching.py", line 184, in __next__
    return self.fetching_function()
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/utilities/fetching.py", line 265, in fetching_function
    self._fetch_next_batch(self.dataloader_iter)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/utilities/fetching.py", line 280, in _fetch_next_batch
    batch = next(iterator)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/jameswong/PycharmProjects/AIM/vision_transformer-main/src/dataset.py", line 217, in __getitem__
    image = self.transform(image)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 95, in __call__
    img = t(img)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 277, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torchvision/transforms/functional.py", line 363, in normalize
    return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
  File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torchvision/transforms/_functional_tensor.py", line 926, in normalize
    return tensor.sub_(mean).div_(std)
RuntimeError: output with shape [1, 32, 32] doesn't match the broadcast shape [3, 32, 32]

Based on the error message, (some of) your images have a single color channel (and are thus grayscale) while your normalization expects RGB images. Either repeat the grayscale channel to create 3 channels or change the normalization.
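If that is the case, a minimal sketch of both options could look like this (the file path and the ImageNet mean/std are placeholders, not taken from your repo):

import cv2
from torchvision import transforms

image_filepath = 'datas/val/bagong/example.jpg'  # hypothetical path

# Option 1: make sure OpenCV decodes to 3 channels and convert BGR -> RGB
image = cv2.imread(image_filepath, cv2.IMREAD_COLOR)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Option 2: repeat a single grayscale channel after ToTensor, before Normalize
fix_channels = transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x)
val_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
    fix_channels,  # [1, 32, 32] -> [3, 32, 32]; RGB tensors pass through unchanged
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])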

Just found the problem: I was transforming the image twice :laughing:
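For anyone else hitting this: the transform is applied once at the top of __getitem__ and again right before the return, so the second pass runs on the already patchified tensor; ToPILImage then presumably yields a single-channel image, which is what Normalize rejects as [1, 32, 32] vs [3, 32, 32]. A sketch of the fixed __getitem__ (everything else in the class unchanged) would be:

    def __getitem__(self, idx):
        image_filepath = self.train_image_paths[idx]

        image = cv2.imread(image_filepath)  # BGR, HWC, e.g. (225, 175, 3)

        idx_to_class = {i: j for i, j in enumerate(self.classes)}
        class_to_idx = {value: key for key, value in idx_to_class.items()}

        label = image_filepath.split('/')[-2]
        label = class_to_idx[label]

        # apply the transform exactly once
        if self.transform:
            image = self.transform(image)

        return image, label

(Building idx_to_class / class_to_idx once in __init__ would also avoid recomputing them for every sample, though that is unrelated to the error.)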