I have a specific order in which I would like to feed samples into my data loader. I have already made a list of indices for each batch.
Which sampler in PyTorch can I use to do this?
I’m not sure whether you plan to pass the indices of an entire batch to Dataset.__getitem__
and load all samples of that batch in a single call (instead of loading a single sample, which would be the standard use case). If so, a custom Sampler
might work, as seen here:
import torch
from torch.utils.data import Dataset, DataLoader, Sampler


class MyDataset(Dataset):
    def __init__(self):
        # 20 samples with one feature each: [[0.], [1.], ..., [19.]]
        self.data = torch.arange(20).view(-1, 1).float()

    def __getitem__(self, index):
        # index is a tensor holding all indices of one batch
        print("received indices {}".format(index))
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)


class MySampler(Sampler):
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        # yield one tensor of indices per batch, in the given order
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


# create indices for each batch
indices = [torch.arange(5, 10),
           torch.arange(0, 5),
           torch.arange(15, 20),
           torch.arange(10, 15),
           ]

sampler = MySampler(indices)
dataset = MyDataset()
loader = DataLoader(dataset, sampler=sampler)

for data in loader:
    print(data)
Output:
# received indices tensor([5, 6, 7, 8, 9])
# tensor([[[5.],
#          [6.],
#          [7.],
#          [8.],
#          [9.]]])
# received indices tensor([0, 1, 2, 3, 4])
# tensor([[[0.],
#          [1.],
#          [2.],
#          [3.],
#          [4.]]])
# received indices tensor([15, 16, 17, 18, 19])
# tensor([[[15.],
#          [16.],
#          [17.],
#          [18.],
#          [19.]]])
# received indices tensor([10, 11, 12, 13, 14])
# tensor([[[10.],
#          [11.],
#          [12.],
#          [13.],
#          [14.]]])
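Note that each printed batch carries an extra leading dimension of size 1, because the DataLoader still collates with its default batch_size=1. If that is not wanted, one option is to disable automatic batching with batch_size=None. The following is only a minimal sketch, assuming a PyTorch version that supports batch_size=None; the names RangeDataset and BatchIndexSampler are made up for illustration.

```python
import torch
from torch.utils.data import Dataset, DataLoader, Sampler


class RangeDataset(Dataset):  # hypothetical name, same data as above
    def __init__(self):
        self.data = torch.arange(20).view(-1, 1).float()

    def __getitem__(self, index):
        # index is a tensor with all indices of one batch
        return self.data[index]

    def __len__(self):
        return len(self.data)


class BatchIndexSampler(Sampler):  # hypothetical name
    def __init__(self, batches):
        self.batches = batches

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)


batches = [torch.arange(5, 10), torch.arange(0, 5)]

# batch_size=None turns off automatic batching, so whatever
# __getitem__ returns is passed through without an added dimension
loader = DataLoader(RangeDataset(),
                    sampler=BatchIndexSampler(batches),
                    batch_size=None)

shapes = [tuple(batch.shape) for batch in loader]
print(shapes)
```

With this setup each batch should have shape (5, 1) rather than (1, 5, 1).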
Is this what you are looking for?
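If you would rather keep the standard behavior, where __getitem__ receives a single index and the DataLoader builds the batch, the precomputed index lists could also be passed via the batch_sampler argument, which accepts any iterable that yields a list of indices per batch. This is a rough sketch under that assumption; PerSampleDataset is a made-up name.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class PerSampleDataset(Dataset):  # hypothetical name
    def __init__(self):
        self.data = torch.arange(20).view(-1, 1).float()

    def __getitem__(self, index):
        # standard use case: index is a single integer
        return self.data[index]

    def __len__(self):
        return len(self.data)


# one list of sample indices per batch, in the desired order
batch_indices = [
    [5, 6, 7, 8, 9],
    [0, 1, 2, 3, 4],
    [15, 16, 17, 18, 19],
    [10, 11, 12, 13, 14],
]

# batch_sampler is mutually exclusive with batch_size, shuffle,
# sampler and drop_last, so those are left at their defaults
loader = DataLoader(PerSampleDataset(), batch_sampler=batch_indices)

first_values = [int(batch[0, 0]) for batch in loader]
print(first_values)
```

Here the samples are loaded one at a time and stacked by the default collate function, so this is the simpler route unless the dataset can fetch a whole batch more efficiently in one call.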