I have searched online and the forums, and the only solution I have found so far is to use this external library: GitHub - szymonmaszke/torchdata: PyTorch dataset extended with map, cache etc. (tensorflow.data like)
So before I use that I was wondering whether there is a solution native to PyTorch.
Here is my dataset:
from PIL import Image
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, fns, transform=None):
        self.fns = fns
        self.transform = transform

    def __len__(self):
        return len(self.fns)

    def __getitem__(self, idx):
        image = Image.open(self.fns[idx]).convert("RGB")
        label = self.fns[idx].parent.name
        if self.transform is not None:
            image = self.transform(image)
        return image, label
For context, the images are all in respective folders but not split to train and valid.
Therefore I use this code to split it randomly:
from torch.utils.data import random_split

dset = CustomDataset(filenames)
num_items = len(dset)
num_train = round(num_items * 0.8)
num_val = num_items - num_train
train_ds, val_ds = random_split(dset, [num_train, num_val])
The problem is that now that I have the two splits, I need to apply different transforms depending on whether the data is training or validation data.
I'm guessing I could create a method like def split_data(self, ...):
inside the class, but I'm not sure how to proceed.
Any leads? Or any better way to construct this custom dataset that does the splitting for us?
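To show what I mean by the split_data idea, here is a rough sketch. All the names and parameters here are my own invention, and I've stripped it down to plain Python so it stands alone; the real class would subclass torch.utils.data.Dataset and open images exactly as CustomDataset does above:

```python
import random

class SplittableDataset:
    """Sketch: the class hands back a train and a valid dataset,
    each constructed with its own transform, instead of splitting
    a single shared dataset after the fact."""

    def __init__(self, fns, transform=None):
        self.fns = fns
        self.transform = transform

    def __len__(self):
        return len(self.fns)

    @classmethod
    def split(cls, fns, train_frac=0.8, train_tfm=None, val_tfm=None, seed=42):
        # Shuffle a copy of the filenames so the split is random but reproducible.
        fns = list(fns)
        random.Random(seed).shuffle(fns)
        num_train = round(len(fns) * train_frac)
        # Each split is a separate instance, so each keeps its own transform.
        return cls(fns[:num_train], train_tfm), cls(fns[num_train:], val_tfm)
```

Something like train_ds, val_ds = SplittableDataset.split(filenames, train_tfm=train_transforms, val_tfm=val_transforms) would then replace the random_split call above, but I don't know if this is the idiomatic way to do it.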