This code snippet might be helpful:
```python
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import BatchSampler, RandomSampler


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100).view(100, 1).float()

    def __getitem__(self, index):
        # index is a list of 10 indices, so x is already a whole batch
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)


dataset = MyDataset()
sampler = BatchSampler(
    RandomSampler(dataset),
    batch_size=10,
    drop_last=False)
loader = DataLoader(
    dataset,
    sampler=sampler)

for data in loader:
    print(data)  # shape [1, 10, 1]; the default collate adds a batch dim of 1
```
Because the `BatchSampler` is passed as the `sampler` argument, the `index` argument inside `__getitem__` will contain 10 random indices, which are used to create the whole batch in a single call to `x = self.data[index]`.
You could replace this with your expensive data loading operation.
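As a sketch of what that could look like: a dataset that memory-maps a large array on disk and reads each batch with one fancy-indexing call. The file name `features.npy` and the array shape are made up for illustration.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import BatchSampler, RandomSampler

# Stand-in for a large file on disk; in practice this would already exist.
np.save("features.npy", np.random.randn(100, 4).astype(np.float32))


class LazyDataset(Dataset):
    def __init__(self, path):
        # mmap_mode="r" keeps the array on disk; only indexed rows are read
        self.data = np.load(path, mmap_mode="r")

    def __getitem__(self, index):
        # index is the list of indices from the BatchSampler, so a single
        # fancy-indexing call loads the whole batch from disk
        return torch.from_numpy(np.asarray(self.data[index]))

    def __len__(self):
        return len(self.data)


dataset = LazyDataset("features.npy")
sampler = BatchSampler(RandomSampler(dataset), batch_size=10, drop_last=False)
loader = DataLoader(dataset, sampler=sampler)

for batch in loader:
    print(batch.shape)  # torch.Size([1, 10, 4])
```

If the extra leading dimension of 1 bothers you, you can squeeze it out in the training loop or in `__getitem__`.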
Let me know if that helps.