I am trying to replace the dataset with this one: wayang_bagong_cepot_gareng_petruk_semar | Kaggle.
You can find my clone and the implementation at https://github.com/jameswong3388/vision-transformer/blob/master/src/dataset.py
Here is what I did to replace CIFAR10DataModule:
import glob
import random

import cv2
import pytorch_lightning as pl
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# PatchifyTransform and flatten are defined elsewhere in src/dataset.py


class WayangKulitDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32, patch_size: int = 4, val_batch_size: int = 16,
                 im_size: int = 32, rotation_degrees: (int, int) = (-30, 30)):
        super().__init__()
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size

        # ViTFeatureExtractor config, for reference:
        # {
        #     "do_normalize": true,
        #     "do_resize": true,
        #     "feature_extractor_type": "ViTFeatureExtractor",
        #     "image_mean": [0.485, 0.456, 0.406],
        #     "image_std": [0.229, 0.224, 0.225],
        #     "resample": 3,
        #     "size": 224
        # }

        self.train_transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(size=(im_size, im_size)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(size=(im_size, im_size)),
                transforms.RandomRotation(degrees=rotation_degrees),
                transforms.ToTensor(),
                transforms.Normalize((0.63528919, 0.57810118, 0.51988552),
                                     (0.33020571, 0.34510824, 0.36673283)),
                PatchifyTransform(patch_size),
            ]
        )
        self.val_transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(size=(im_size, im_size)),
                # transforms.Lambda(lambda x: x.squeeze(0)),
                transforms.ToTensor(),
                transforms.Normalize((0.63528919, 0.57810118, 0.51988552),
                                     (0.33020571, 0.34510824, 0.36673283)),
                PatchifyTransform(patch_size),
            ]
        )

        self.ds_train = None
        self.ds_val = None

    def setup(self, stage: str):
        self.ds_train = WayangKulit('datas/train', transform=self.train_transform)
        self.ds_val = WayangKulit('datas/val', transform=self.val_transform)

    def train_dataloader(self):
        return DataLoader(self.ds_train, batch_size=self.batch_size)

    def val_dataloader(self):
        # use the dedicated validation batch size instead of the training one
        return DataLoader(self.ds_val, batch_size=self.val_batch_size)

    @property
    def classes(self):
        """Returns the number of classes in the Wayang Kulit dataset."""
        return 5  # the wayang kulit dataset has 5 classes
class WayangKulit(Dataset):
    def __init__(self, dataset_path, transform):
        self.dataset_path = dataset_path
        self.transform = transform
        self.dataset = load_dataset(path=self.dataset_path, name='wayang_kulit')

        # collect the class names (one sub-folder per class) and all image paths
        self.train_image_paths = []
        self.classes = []
        for data_path in glob.glob(self.dataset_path + '/*'):
            self.classes.append(data_path.split('/')[-1])
            self.train_image_paths.append(glob.glob(data_path + '/*'))
        self.train_image_paths = list(flatten(self.train_image_paths))
        random.shuffle(self.train_image_paths)

    def __len__(self):
        # length must match the indexing in __getitem__, which uses train_image_paths
        return len(self.train_image_paths)

    def __getitem__(self, idx):
        image_filepath = self.train_image_paths[idx]
        image = cv2.imread(image_filepath)  # (225, 175, 3), BGR
        if self.transform:
            image = self.transform(image)

        # map the class-folder name in the path to an integer label
        idx_to_class = {i: j for i, j in enumerate(self.classes)}
        class_to_idx = {value: key for key, value in idx_to_class.items()}
        label = image_filepath.split('/')[-2]
        label = class_to_idx[label]

        if self.transform:
            # note: the transform is applied a second time here
            image = self.transform(image)
        return image, label
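The mean/std values passed to Normalize above are per-channel statistics for this dataset rather than the ImageNet values from the ViTFeatureExtractor config. For context, such statistics can be computed roughly as in the sketch below (an illustration only, assuming the same datas/train/<class>/<image> folder layout; not the exact script behind the numbers above):

import glob

import cv2
import numpy as np

# Sketch: per-channel mean/std over the training images.
# cv2.imread returns BGR; no channel reordering is done here, matching
# how the images enter the transform pipeline above.
pixels = []
for path in glob.glob('datas/train/*/*'):
    img = cv2.imread(path)
    img = cv2.resize(img, (32, 32)) / 255.0  # scale to [0, 1], like ToTensor
    pixels.append(img.reshape(-1, 3))
pixels = np.concatenate(pixels)
print('mean:', pixels.mean(axis=0))
print('std:', pixels.std(axis=0))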
Since the images in the dataset are 175 x 225, I had to resize them to 32x32 for them to work with the original repo,
but I got the following error (a quick shape check of the pipeline is sketched after the traceback):
Traceback (most recent call last):
File "/Users/jameswong/PycharmProjects/AIM/vision_transformer-main/train.py", line 72, in <module>
trainer.fit(model, data)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
call._call_and_handle_interrupt(
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
self._run(model, ckpt_path=self.ckpt_path)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in _run
results = self._run_stage()
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1191, in _run_stage
self._run_train()
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1204, in _run_train
self._run_sanity_check()
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1276, in _run_sanity_check
val_loop.run()
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 121, in advance
batch = next(data_fetcher)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/utilities/fetching.py", line 184, in __next__
return self.fetching_function()
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/utilities/fetching.py", line 265, in fetching_function
self._fetch_next_batch(self.dataloader_iter)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/pytorch_lightning/utilities/fetching.py", line 280, in _fetch_next_batch
batch = next(iterator)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
data = self._next_data()
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/Users/jameswong/PycharmProjects/AIM/vision_transformer-main/src/dataset.py", line 217, in __getitem__
image = self.transform(image)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 95, in __call__
img = t(img)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 277, in forward
return F.normalize(tensor, self.mean, self.std, self.inplace)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torchvision/transforms/functional.py", line 363, in normalize
return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
File "/Users/jameswong/PycharmProjects/AIM/venv/lib/python3.9/site-packages/torchvision/transforms/_functional_tensor.py", line 926, in normalize
return tensor.sub_(mean).div_(std)
RuntimeError: output with shape [1, 32, 32] doesn't match the broadcast shape [3, 32, 32]
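The error is raised by Normalize, whose 3-element mean/std only broadcast against a [3, 32, 32] tensor. For reference, this is the kind of standalone shape check mentioned above; the sample path is only a placeholder for any file under datas/val, and PatchifyTransform and Normalize are left out so just the channel count is visible:

import cv2
from torchvision import transforms

# Placeholder path: any image under datas/val/<class>/ works here.
sample_path = 'datas/val/bagong/example.jpg'

img = cv2.imread(sample_path)
print('after cv2.imread:', img.shape)  # expected (H, W, 3)

# Same steps as val_transform up to (but not including) Normalize,
# with PatchifyTransform omitted.
pipeline = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
])
x = pipeline(img)
print('before Normalize:', x.shape)  # Normalize only broadcasts if this is [3, 32, 32]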