Yes, exactly!
Here is a small code snippet to check the passed indices and the returned values:
```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100).float().view(-1, 1)

    def __getitem__(self, indices):
        # print whatever index object the DataLoader passes in
        print('dataset indices {}'.format(indices))
        x = self.data[indices]
        return x

    def __len__(self):
        return len(self.data)


dataset = MyDataset()
loader = DataLoader(
    dataset,
    # EDIT2: changed to sampler=...
    sampler=BatchSampler(RandomSampler(dataset), 2, False),
)

for x in loader:
    print('DataLoader loop {}'.format(x))
```
EDIT: wait, this looks wrong. Let me debug it quickly.
EDIT2: Based on this code snippet, the indices would be passed directly to `__getitem__` if `sampler=BatchSampler(...)` is used (which is also the case in my example code here).
This seems to be an “edge case”, as it would be similar to disabling automatic batching, but would use the sampler instead. Note that the default `batch_size=1` still applies in this setup, so each returned batch gets an additional leading dimension of size 1.
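To compare the setups side by side, here is a small sketch that records what `__getitem__` receives and the shape of each batch. It assumes a `SequentialSampler` instead of the `RandomSampler` above (so the output is deterministic), and `RecordingDataset` and `run` are just hypothetical helper names for this check. It contrasts `sampler=BatchSampler(...)` with the default `batch_size=1`, the same sampler with `batch_size=None` (automatic batching disabled), and `batch_sampler=BatchSampler(...)`.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class RecordingDataset(Dataset):
    """Hypothetical toy dataset that records every index object passed to __getitem__."""

    def __init__(self, n=6):
        self.data = torch.arange(n).float().view(-1, 1)
        self.calls = []

    def __getitem__(self, index):
        self.calls.append(index)  # an int or a list of ints, depending on the DataLoader setup
        return self.data[index]

    def __len__(self):
        return len(self.data)


def run(make_loader):
    """Iterate a DataLoader once; return (recorded indices, batch shapes)."""
    dataset = RecordingDataset()
    # SequentialSampler (not RandomSampler) so the recorded indices are deterministic
    sampler = BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False)
    loader = make_loader(dataset, sampler)
    shapes = [tuple(batch.shape) for batch in loader]
    return dataset.calls, shapes


# 1) sampler=BatchSampler(...) with the default batch_size=1:
#    __getitem__ gets a list of indices and collate adds a leading dim of size 1
calls, shapes = run(lambda ds, s: DataLoader(ds, sampler=s))
print(calls[0], shapes[0])  # [0, 1] (1, 2, 1)

# 2) same sampler with batch_size=None (automatic batching disabled):
#    __getitem__ still gets a list, but no extra dimension is added
calls, shapes = run(lambda ds, s: DataLoader(ds, sampler=s, batch_size=None))
print(calls[0], shapes[0])  # [0, 1] (2, 1)

# 3) batch_sampler=BatchSampler(...):
#    __getitem__ is called once per single index and the samples are stacked
calls, shapes = run(lambda ds, s: DataLoader(ds, batch_sampler=s))
print(calls[:2], shapes[0])  # [0, 1] (2, 1)
```

The third case holds as long as the dataset does not define `__getitems__`, which newer PyTorch versions use to fetch a whole batch at once.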