I think `BatchSampler` will pass all indices of the current batch at once to your `Dataset`'s `__getitem__` method, as seen in this example:
```python
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100).view(100, 1).float()

    def __getitem__(self, index):
        # index is a list of batch_size indices yielded by the BatchSampler
        print(index)
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)


dataset = MyDataset()
sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(dataset),
    batch_size=10,
    drop_last=False)

loader = DataLoader(
    dataset,
    sampler=sampler)

for data in loader:
    print(data)
```
If you run it, you'll see that the `index` passed to `Dataset.__getitem__` contains 10 indices, which can be used to slice the data directly. Note that with the default `batch_size=1` the `DataLoader`'s collate function adds an extra leading dimension, so each `data` tensor will have the shape `[1, 10, 1]`.
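If that extra dimension is unwanted, a possible alternative (a sketch, assuming a PyTorch version >= 1.2, where `batch_size=None` disables automatic batching) is to let the `DataLoader` return exactly what `__getitem__` produces:

```python
# Sketch: batch_size=None disables automatic batching, so no collate
# wrapping happens and each batch keeps the shape [10, 1]
loader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=None)

for data in loader:
    print(data.shape)  # torch.Size([10, 1])
```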
Let me know if I misunderstood the question.