Force DataLoader to fetch batched index from custom batch sampler

I think the BatchSampler will pass the whole list of batch indices to your Dataset's __getitem__ method, as seen in this example:

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100).view(100, 1).float()

    def __getitem__(self, index):
        # index is the whole list of indices produced by the BatchSampler
        print(index)
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
# BatchSampler groups the indices from RandomSampler into batches of 10
sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(dataset),
    batch_size=10,
    drop_last=False)

# Pass the BatchSampler as the plain sampler argument, so each "index"
# the DataLoader fetches is actually a whole batch of indices
loader = DataLoader(
    dataset,
    sampler=sampler)

for data in loader:
    print(data)

If you run it, you'll see that the index inside Dataset.__getitem__ contains all 10 indices of the batch, which can be used to slice self.data directly. Note that since the DataLoader's default batch_size=1 still applies, each fetched batch is wrapped once more by the default collate, so the printed batches will have an extra leading dimension of size 1 (shape [1, 10, 1] here).
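If that extra dimension is unwanted, a minimal sketch of one way around it: setting batch_size=None disables the DataLoader's automatic batching (available since PyTorch 1.2), so each index list yielded by the BatchSampler is passed through as-is and no further collation happens.

# Sketch: disable automatic batching so the DataLoader does not add
# another batch dimension around each fetched batch.
# Assumes PyTorch >= 1.2, where batch_size=None turns off auto-batching.
loader = DataLoader(
    dataset,
    sampler=sampler,   # the BatchSampler from above
    batch_size=None)   # each sampler element is treated as one batch

for data in loader:
    print(data.shape)  # should print torch.Size([10, 1])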

Let me know if I misunderstood the question.
