Hi! I read this section several times and tried to google my question, still cannot figure it out.
My dataset is a tensor that fits in memory, so using batched indexing is much more efficient than indexing items one by one and collating them. Below, I provide two solutions I came up with, but is there a better way?
Option A
In this case I don't fully utilize `num_workers` and `pin_memory`. I didn't measure whether this matters in my case, but anyway:
```python
from torch.utils.data import Dataset, DataLoader

class IndicesDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return index

# The default collate_fn stacks the integer indices into a tensor,
# so each batch can index X and Y in a single call.
for batch_idx in DataLoader(IndicesDataset(size), *args, **kwargs):
    x, y = X[batch_idx], Y[batch_idx]
```
Option B
Make a dataset whose `__getitem__` is aware of batches, i.e. accepts a list of indices and returns a whole batch. Then:
```python
dataloader = DataLoader(dataset, batch_size=None, sampler=BatchSampler(...))
```
where `BatchSampler` yields lists of indices (which directly contradicts the meaning of the `sampler` argument, I know).
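For concreteness, here is a minimal runnable sketch of what I mean by Option B. The class name `TensorBatchDataset` and the tensor shapes are just placeholders I made up; the point is that `batch_size=None` disables automatic batching, so the lists produced by `BatchSampler` reach `__getitem__` as-is:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler

class TensorBatchDataset(Dataset):
    """Batch-aware dataset: __getitem__ receives a *list* of indices
    (coming from BatchSampler) and returns a whole batch at once."""
    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, indices):
        # One advanced-indexing call per batch instead of
        # per-item indexing plus collation.
        return self.x[indices], self.y[indices]

X, Y = torch.randn(100, 8), torch.randint(0, 2, (100,))
dataset = TensorBatchDataset(X, Y)

# batch_size=None turns off automatic batching, so the DataLoader
# passes each list yielded by BatchSampler straight to __getitem__.
loader = DataLoader(
    dataset,
    batch_size=None,
    sampler=BatchSampler(RandomSampler(dataset), batch_size=16, drop_last=False),
)

for x, y in loader:
    # x is (16, 8) for full batches, (4, 8) for the last one
    pass
```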