By default, data.DataLoader indexes elements of a batch one by one and collates them back into tensors. I have a dataset (subclass of data.Dataset) which can be indexed (efficiently) by slices. For example, the following can be indexed by slices:
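from typing import Tuple

from torch import Tensor
import torch.utils.data as data

class MyDataset(data.Dataset):
    def __init__(self, a: Tensor, b: Tensor):
        self.a, self.b = a, b

    def __len__(self) -> int:
        return len(self.a)

    def __getitem__(self, i) -> Tuple[Tensor, Tensor]:
        # i can be an int or a slice; both index a and b directly
        return self.a[i], self.b[i]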
I would like the DataLoader to take advantage of that. I think I need to use a data.Sampler, but I couldn’t figure out how from the documentation. Any ideas?
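A minimal sketch of the default behaviour described in the question, assuming a recent PyTorch version: with batch_size=4, DataLoader calls __getitem__ once per sample and default_collate stacks the results into a batch. CountingDataset is a hypothetical name used only to record the calls.

```python
import torch
from torch.utils import data


class CountingDataset(data.Dataset):
    """Hypothetical dataset that records every argument passed to __getitem__."""

    def __init__(self, n: int):
        self.x = torch.arange(n).float()
        self.seen = []  # every index __getitem__ received, in order

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, i):
        self.seen.append(i)
        return self.x[i]


ds = CountingDataset(8)
loader = data.DataLoader(ds, batch_size=4, shuffle=False)
batches = list(loader)

# Default path: one __getitem__ call per sample, then default_collate stacks them.
print(len(ds.seen))      # 8
print(batches[0].shape)  # torch.Size([4])
```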
Yes, exactly!
Here is a small code snippet to check the passed indices and the returned values:
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100).float().view(-1, 1)

    def __getitem__(self, indices):
        # Print whatever the DataLoader passes in: a single index or a list of indices
        print('dataset indices {}'.format(indices))
        x = self.data[indices]
        return x

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
loader = DataLoader(
    dataset,
    # EDIT2: changed to sampler=...
    sampler=BatchSampler(RandomSampler(dataset), 2, False),
)
for x in loader:
    print('DataLoader loop {}'.format(x))
EDIT: wait, this looks wrong. Let me debug it quickly.
EDIT2: Based on this code snippet, the whole list of indices is passed directly to __getitem__ if sampler=BatchSampler(...) is used (as in my example code above).
This might be an “edge case”, as it is similar to disabling automatic batching, but uses the sampler instead.
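A small sketch of what EDIT2 describes, including the extra dimension that appears with the default batch_size=1. RecordingDataset is a hypothetical name, and SequentialSampler is used instead of RandomSampler only to make the indices deterministic. With batch_size=1, DataLoader still wraps each list coming from the BatchSampler into a batch of one, so collation adds a leading dimension; batch_size=None disables automatic batching and returns __getitem__'s output unchanged.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class RecordingDataset(Dataset):
    """Hypothetical dataset that records what __getitem__ receives."""

    def __init__(self):
        self.data = torch.arange(6).float().view(-1, 1)
        self.calls = []  # every argument __getitem__ received

    def __len__(self):
        return len(self.data)

    def __getitem__(self, indices):
        self.calls.append(indices)
        return self.data[indices]


ds = RecordingDataset()
sampler = BatchSampler(SequentialSampler(ds), batch_size=2, drop_last=False)

# Default batch_size=1: the list [0, 1] is treated as ONE sample and then collated.
out_default = next(iter(DataLoader(ds, sampler=sampler)))
print(ds.calls[0], out_default.shape)  # [0, 1] torch.Size([1, 2, 1])

# batch_size=None: automatic batching is off, so no extra dimension is added.
out_none = next(iter(DataLoader(ds, sampler=sampler, batch_size=None)))
print(out_none.shape)  # torch.Size([2, 1])
```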
Thanks again, @ptrblck! I also had to use batch_size=None to prevent the collate function from adding an additional dimension. So the final snippet is:
import torch
import torch.utils.data as data

class MyDataset(data.Dataset):
    def __init__(self):
        self.data = torch.rand(100, 5)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, i) -> torch.Tensor:
        print(i)  # a list of 2 indices from the BatchSampler
        return self.data[i]

dataset = MyDataset()
sampler = data.BatchSampler(data.RandomSampler(dataset), 2, False)
# batch_size=None disables automatic batching, so no extra batch dimension is added
loader = data.DataLoader(dataset, sampler=sampler, batch_size=None)
for x in loader:
    print(x)
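Note that BatchSampler yields lists of indices, so self.data[i] above uses advanced indexing, which copies the selected rows. If the dataset really should be indexed by contiguous slices, as in the original question, one option is a sampler that yields slice objects instead. The sketch below assumes that approach; SliceSampler is a hypothetical class, not part of torch.utils.data, and it only produces contiguous batches in order (no shuffling).

```python
import math
from typing import Iterator, Tuple

import torch
from torch import Tensor
from torch.utils import data


class MyDataset(data.Dataset):
    # The dataset from the question: a slice index returns views of a and b.
    def __init__(self, a: Tensor, b: Tensor):
        self.a, self.b = a, b

    def __len__(self) -> int:
        return len(self.a)

    def __getitem__(self, i) -> Tuple[Tensor, Tensor]:
        return self.a[i], self.b[i]


class SliceSampler(data.Sampler):
    """Hypothetical sampler that yields contiguous slices instead of index lists."""

    def __init__(self, n: int, batch_size: int):
        self.n, self.batch_size = n, batch_size

    def __iter__(self) -> Iterator[slice]:
        for start in range(0, self.n, self.batch_size):
            yield slice(start, min(start + self.batch_size, self.n))

    def __len__(self) -> int:
        return math.ceil(self.n / self.batch_size)


a, b = torch.arange(10).float(), torch.arange(10) * 2
dataset = MyDataset(a, b)
# batch_size=None: each slice is passed straight to __getitem__, with no collation.
loader = data.DataLoader(dataset, sampler=SliceSampler(len(dataset), 4), batch_size=None)
shapes = [tuple(xa.shape) for xa, xb in loader]
print(shapes)  # [(4,), (4,), (2,)]
```

Shuffling could still be added by yielding the slices in a random order, at the cost of batches that always contain the same contiguous blocks.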
By the way, I feel that using the keyword sampler instead of batch_sampler for sampling batches is a bit weird. Perhaps the interface could be improved a little? Although, if it ain’t broke, don’t fix it.
Haha, I had the same feeling, but I’m sure I misunderstood it.
Checking the code again, using sampler=Sampler and batch_sampler=BatchSampler yields the same behavior inside the Dataset: __getitem__ still receives one index at a time. The only difference is that your BatchSampler can now yield multiple indices per batch. This workflow is useful if Dataset.__getitem__ should stay the same while you manipulate the sampler.
However, as shown in our code snippets, you could also pass a BatchSampler to the sampler argument, which then passes the whole list of indices to __getitem__ in a single call.
While the naming is a bit confusing, it’s a convenient way to load multiple samples in a single __getitem__ call. I’m also not sure whether this is a “supported” use case, but so far it seems to work fine.
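A sketch contrasting the two setups discussed in this reply. RecordingDataset is a hypothetical name, and SequentialSampler is used only to keep the indices deterministic. With batch_sampler=, __getitem__ is called once per index and the results are collated; with sampler= plus batch_size=None, __getitem__ receives the whole list in one call. Both end up with the same batch.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class RecordingDataset(Dataset):
    """Hypothetical dataset that records what __getitem__ receives."""

    def __init__(self):
        self.data = torch.arange(6).float().view(-1, 1)
        self.calls = []  # every argument __getitem__ received

    def __len__(self):
        return len(self.data)

    def __getitem__(self, indices):
        self.calls.append(indices)
        return self.data[indices]


# batch_sampler=: one __getitem__ call per index, then default collation.
ds1 = RecordingDataset()
bs1 = BatchSampler(SequentialSampler(ds1), batch_size=2, drop_last=False)
x1 = next(iter(DataLoader(ds1, batch_sampler=bs1)))
print(ds1.calls, x1.shape)  # [0, 1] torch.Size([2, 1])

# sampler= with batch_size=None: a single __getitem__ call with the whole list.
ds2 = RecordingDataset()
bs2 = BatchSampler(SequentialSampler(ds2), batch_size=2, drop_last=False)
x2 = next(iter(DataLoader(ds2, sampler=bs2, batch_size=None)))
print(ds2.calls, x2.shape)  # [[0, 1]] torch.Size([2, 1])
```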
@VitalyFedyunin are we misusing the sampler or does it make sense?