How to use BatchSampler with __getitem__ dataset

This code snippet might be helpful:

```python
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100).view(100, 1).float()

    def __getitem__(self, index):
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)


dataset = MyDataset()
sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(dataset),
    batch_size=10,
    drop_last=False)

loader = DataLoader(
    dataset,
    sampler=sampler)

for data in loader:
    print(data)
```

The `index` argument passed to `__getitem__` will be a list of 10 random indices, which are used to create the whole batch in one call:

```python
x = self.data[index]
```

You could replace this with your expensive data loading operation.
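One thing to watch out for: since the `BatchSampler` is passed as a plain `sampler`, the default automatic batching (with `batch_size=1`) still runs and wraps each already-batched tensor in an extra leading dimension of size 1. A small sketch of both behaviors, assuming a PyTorch version recent enough that `DataLoader` accepts `batch_size=None` to disable automatic batching:

```python
import torch
from torch.utils.data import Dataset, DataLoader, BatchSampler, RandomSampler


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100).view(100, 1).float()

    def __getitem__(self, index):
        # index is a list of 10 indices, so this returns a [10, 1] tensor
        return self.data[index]

    def __len__(self):
        return len(self.data)


dataset = MyDataset()
sampler = BatchSampler(RandomSampler(dataset), batch_size=10, drop_last=False)

# Default automatic batching (batch_size=1) stacks the single [10, 1]
# sample into an extra leading dimension: [1, 10, 1].
wrapped = next(iter(DataLoader(dataset, sampler=sampler)))
print(wrapped.shape)  # torch.Size([1, 10, 1])

# batch_size=None disables automatic batching, so the [10, 1] tensor
# returned by __getitem__ is yielded unchanged.
flat = next(iter(DataLoader(dataset, sampler=sampler, batch_size=None)))
print(flat.shape)  # torch.Size([10, 1])
```

If the extra dimension is unwanted, `batch_size=None` (or a `.squeeze(0)` on each batch) avoids it.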

Let me know if that helps.
