Hi gregunz, apologies for the late reply.
Your code is great, but `__getitem__` needed a small change to access the images and labels in my case.
I combined your code with the code from the thread "Using ImageFolder, random_split with multiple transforms".
The resulting code seems to work for me. Let me know if I did it correctly. If there are any mistakes, feel free to refine the code, and then I will accept it as the solution.
```python
import torch


class MapDataset(torch.utils.data.Dataset):
    """
    Given a dataset, creates a dataset which applies a mapping function
    to its items (lazily, only when an item is called).

    Note that data is not cloned/copied from the initial dataset.
    """

    def __init__(self, dataset, map_fn):
        self.dataset = dataset
        self.map = map_fn

    def __getitem__(self, index):
        # Fetch the (image, label) pair once; indexing the wrapped dataset
        # twice would load and decode the same image two times.
        x, y = self.dataset[index]
        if self.map:
            x = self.map(x)  # apply the transform to the image only
        return x, y

    def __len__(self):
        return len(self.dataset)
```
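A minimal usage sketch, assuming the goal from that thread: split one dataset once with `random_split`, then give each split its own transform. So that it runs without image files, a `TensorDataset` of zero tensors stands in for `ImageFolder`, and `train_tf`/`val_tf` are hypothetical placeholder transforms. With real data you would wrap `ImageFolder(root)` (created without a transform) and use torchvision transforms instead. The class is repeated so the snippet is self-contained.

```python
import torch
from torch.utils.data import Dataset, TensorDataset, random_split


class MapDataset(Dataset):
    """Lazily applies map_fn to the image of each (image, label) pair."""

    def __init__(self, dataset, map_fn=None):
        self.dataset = dataset
        self.map = map_fn

    def __getitem__(self, index):
        x, y = self.dataset[index]
        if self.map:
            x = self.map(x)
        return x, y

    def __len__(self):
        return len(self.dataset)


# Toy stand-in for ImageFolder: 10 fake 3x4x4 "images" with labels 0..9.
images = torch.zeros(10, 3, 4, 4)
labels = torch.arange(10)
base = TensorDataset(images, labels)

# Split once, with a fixed seed so the split is reproducible.
train_subset, val_subset = random_split(
    base, [8, 2], generator=torch.Generator().manual_seed(0)
)

# Hypothetical transforms: modify training images, leave validation images as-is.
train_tf = lambda img: img + 1.0
val_tf = None

train_ds = MapDataset(train_subset, train_tf)
val_ds = MapDataset(val_subset, val_tf)

x, y = train_ds[0]
print(len(train_ds), len(val_ds), x.mean().item())
```

Since the transform lives in `MapDataset` rather than in the base dataset, both subsets share the same underlying data and the same split, while each one sees only its own transform.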