Sampler with set indices

I want my data loader to return batches in a specific order, and I have already built a list of indices for each batch.

Which sampler in PyTorch can I use to do this?

I'm not sure whether you plan to pass the indices of an entire batch to Dataset.__getitem__ and load all samples of that batch at once. The standard use case passes a single index per call. If you do want whole batches, a custom Sampler that yields one index tensor per batch might work, as seen here:

import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class MyDataset(Dataset):
    def __init__(self):
        # 20 samples with one feature each, shape [20, 1]
        self.data = torch.arange(20).view(-1, 1).float()

    def __getitem__(self, index):
        # index is a tensor holding all indices of one batch
        print("received indices {}".format(index))
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)


class MySampler(Sampler):
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        # yield one index tensor per batch, in the given order
        return iter(self.indices)

    def __len__(self):
        # number of batches
        return len(self.indices)


# create indices for each batch
indices = [torch.arange(5, 10),
           torch.arange(0, 5),
           torch.arange(15, 20),
           torch.arange(10, 15),
]

sampler = MySampler(indices)
dataset = MyDataset()
loader = DataLoader(dataset, sampler=sampler)
for data in loader:
    print(data)

Output:

# received indices tensor([5, 6, 7, 8, 9])
# tensor([[[5.],
#          [6.],
#          [7.],
#          [8.],
#          [9.]]])
# received indices tensor([0, 1, 2, 3, 4])
# tensor([[[0.],
#          [1.],
#          [2.],
#          [3.],
#          [4.]]])
# received indices tensor([15, 16, 17, 18, 19])
# tensor([[[15.],
#          [16.],
#          [17.],
#          [18.],
#          [19.]]])
# received indices tensor([10, 11, 12, 13, 14])
# tensor([[[10.],
#          [11.],
#          [12.],
#          [13.],
#          [14.]]])
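The extra leading dimension of size 1 in the output above comes from the DataLoader's default batch_size=1, which wraps each already-batched sample in another batch. A minimal sketch, assuming you want to drop that dimension: passing batch_size=None disables automatic batching, so each [5, 1] tensor returned by __getitem__ should be yielded unchanged.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class MyDataset(Dataset):
    def __init__(self):
        # 20 samples with one feature each, shape [20, 1]
        self.data = torch.arange(20).view(-1, 1).float()

    def __getitem__(self, index):
        # index is a whole tensor of indices, so this returns a full batch
        return self.data[index]

    def __len__(self):
        return len(self.data)


class MySampler(Sampler):
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        # yield one index tensor per batch, in the given order
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


indices = [torch.arange(5, 10), torch.arange(0, 5),
           torch.arange(15, 20), torch.arange(10, 15)]

# batch_size=None turns off automatic batching: each item from __getitem__
# is returned as-is instead of being stacked into a batch of size 1
loader = DataLoader(MyDataset(), sampler=MySampler(indices), batch_size=None)
batches = list(loader)
for batch in batches:
    print(batch.shape)
```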

Is this what you are looking for?
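If you would rather keep the standard per-sample __getitem__, another option is to pass your list of per-batch index lists directly as the batch_sampler argument of DataLoader, which the documentation describes as accepting a Sampler or any iterable of index lists. This is a sketch rather than a drop-in solution; the dataset name SingleSampleDataset is hypothetical.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SingleSampleDataset(Dataset):
    def __init__(self):
        # 20 samples with one feature each, shape [20, 1]
        self.data = torch.arange(20).view(-1, 1).float()

    def __getitem__(self, index):
        # standard use case: index is a single int, one sample is returned
        return self.data[index]

    def __len__(self):
        return len(self.data)


# one list of indices per batch, in the desired order
batch_indices = [list(range(5, 10)), list(range(0, 5)),
                 list(range(15, 20)), list(range(10, 15))]

# each inner list becomes one batch; the default collate_fn stacks
# the per-sample [1] tensors into a [5, 1] batch
loader = DataLoader(SingleSampleDataset(), batch_sampler=batch_indices)
batches = list(loader)
for batch in batches:
    print(batch.squeeze(1))
```

Here __getitem__ is called once per index, so no custom Sampler class is needed; the trade-off is that samples are fetched one at a time instead of with a single tensor indexing call.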