Hi! I read this section several times and tried to google my question, but I still cannot figure it out.
My dataset is a tensor that fits in memory, so using batched indexing is much more efficient than indexing items one by one and collating them. Below, I provide two solutions I came up with, but is there a better way?
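To illustrate what I mean by "more efficient" (shapes and sizes here are made up, and the timing is only illustrative), one advanced-indexing call produces the same batch as per-item indexing plus stacking, which is roughly what the default collation does:

```python
import time
import torch

X = torch.randn(100_000, 32)                    # hypothetical in-memory dataset
idx = torch.randint(0, len(X), (256,))          # one batch of indices

# Per-item indexing followed by a stack (what per-item __getitem__ + collate does)
t0 = time.perf_counter()
batch_slow = torch.stack([X[i] for i in idx])
t_slow = time.perf_counter() - t0

# A single batched advanced-indexing call
t0 = time.perf_counter()
batch_fast = X[idx]
t_fast = time.perf_counter() - t0

print(f"per-item: {t_slow:.6f}s, batched: {t_fast:.6f}s")
```

Both produce identical batches; the batched version avoids the Python-level loop and the extra copy from stacking.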
1. A dataset that yields indices. In this case I don’t fully utilize `pin_memory`. I didn’t measure whether this matters in my case, but anyway:
```python
from torch.utils.data import Dataset, DataLoader

class IndicesDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return index

# Each batch_idx is a collated batch of indices; use it to index X and Y directly.
for batch_idx in DataLoader(IndicesDataset(size), *args, **kwargs):
    x, y = X[batch_idx], Y[batch_idx]
```
2. Make a dataset whose `__getitem__` is aware of batches, i.e. accepts a list of indices rather than a single index. Then:
```python
dataloader = DataLoader(dataset, batch_size=None, sampler=BatchSampler(...))
```
where `BatchSampler` yields lists of indices (which directly contradicts the meaning of the `sampler` argument, I know).