You could adapt this code: it wraps torchvision's CIFAR10 in a custom Dataset whose __getitem__ also returns the sample index, so every batch yielded by the DataLoader carries the original dataset indices alongside the data and targets:
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self):
        # Wrap the torchvision CIFAR10 dataset; point root at your local copy.
        self.cifar10 = datasets.CIFAR10(root='YOUR_PATH',
                                        download=False,
                                        train=True,
                                        transform=transforms.ToTensor())

    def __getitem__(self, index):
        data, target = self.cifar10[index]
        # Your transformations here (or set them in CIFAR10)
        # Return the index as a third element so it survives batching.
        return data, target, index

    def __len__(self):
        return len(self.cifar10)


dataset = MyDataset()
loader = DataLoader(dataset,
                    batch_size=1,
                    shuffle=True,
                    num_workers=1)

for batch_idx, (data, target, idx) in enumerate(loader):
    print('Batch idx {}, dataset index {}'.format(batch_idx, idx))
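If you need the indices for something concrete, here is a minimal sketch of one common use: recording a per-sample loss keyed by the original dataset index. The model and loss here are placeholders of my own (a tiny linear classifier and cross-entropy), not part of your setup; any nn.Module producing class logits would do.

import torch
import torch.nn.functional as F

# Placeholder model for illustration only: flatten 3x32x32 CIFAR10 images
# and map them to 10 class logits.
model = torch.nn.Sequential(torch.nn.Flatten(),
                            torch.nn.Linear(3 * 32 * 32, 10))

per_sample_loss = {}
with torch.no_grad():
    for data, target, idx in loader:
        output = model(data)
        # reduction='none' keeps one loss value per sample in the batch
        losses = F.cross_entropy(output, target, reduction='none')
        for i, l in zip(idx.tolist(), losses.tolist()):
            per_sample_loss[i] = l

Because __getitem__ returns the index, the DataLoader collates it into a tensor per batch, so idx.tolist() lines up element-for-element with the per-sample losses even when shuffle=True.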