How to make DataLoader use batched indexing?

Hi! I read this section several times and tried googling my question, but I still cannot figure it out.

My dataset is a tensor that fits in memory, so using batched indexing is much more efficient than indexing items one by one and collating them. Below, I provide two solutions I came up with, but is there a better way?

Option A

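In this case I don’t fully utilize num_workers and pin_memory, because the DataLoader only produces indices and the actual tensor indexing happens in the main loop. I haven’t measured whether this matters in my case, but anyway: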

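from torch.utils.data import DataLoader, Dataset

class IndicesDataset(Dataset):
    # A dataset that returns the index itself, so the DataLoader
    # collates indices into a batch instead of collating data items.
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return index

# X, Y are the in-memory tensors, size == len(X); *args, **kwargs are the usual DataLoader options.
for batch_idx in DataLoader(IndicesDataset(size), *args, **kwargs):
    # batch_idx is a 1-D tensor of indices, so this is a single batched lookup
    x, y = X[batch_idx], Y[batch_idx]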

Option B

Make a dataset with __getitem__ that is aware of batches. Then:

dataloader = DataLoader(dataset, batch_size=None, sampler=BatchSampler(...))

where BatchSampler yields lists of indices (which directly contradicts the meaning of the sampler argument, I know).
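To make Option B concrete, here is a minimal sketch of what I have in mind. The class name BatchedTensorDataset and the toy X, Y tensors are made up for illustration; the sketch assumes that batch_size=None disables automatic batching, so each list of indices from the sampler reaches __getitem__ unchanged.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler


class BatchedTensorDataset(Dataset):
    """Hypothetical dataset whose __getitem__ receives a whole list of indices."""

    def __init__(self, X, Y):
        self.X, self.Y = X, Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, indices):
        # `indices` is a list of ints produced by BatchSampler;
        # a single advanced-indexing call fetches the whole batch.
        return self.X[indices], self.Y[indices]


# Toy data, made up for illustration: row i of X is [2i, 2i + 1], Y[i] = i.
X = torch.arange(20, dtype=torch.float32).reshape(10, 2)
Y = torch.arange(10)

dataset = BatchedTensorDataset(X, Y)
sampler = BatchSampler(RandomSampler(dataset), batch_size=4, drop_last=False)

# batch_size=None turns off automatic batching and collation,
# so the dataset is indexed once per batch rather than once per item.
loader = DataLoader(dataset, batch_size=None, sampler=sampler)

batches = [(x, y) for x, y in loader]
```

Each element of batches is then an (x, y) pair that was built with one indexing operation per tensor.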

Both approaches would most likely work, although I think you could drop the DataLoader from the first approach entirely and add the shuffling manually, if needed.

If you have already loaded the data into RAM, you can get the samples directly via indexing or gather.
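A rough sketch of what that could look like without any DataLoader, assuming X and Y are tensors with the same first dimension. The helper name iterate_batches is made up for this example, and torch.randperm stands in for the shuffling a sampler would otherwise do.

```python
import torch


def iterate_batches(X, Y, batch_size, shuffle=True):
    """Yield (x, y) batches from in-memory tensors using batched indexing."""
    n = X.shape[0]
    # A fresh permutation on each call gives a new order every epoch.
    if shuffle:
        order = torch.randperm(n, device=X.device)
    else:
        order = torch.arange(n, device=X.device)
    # split() cuts the index tensor into chunks of batch_size (the last may be smaller).
    for idx in order.split(batch_size):
        yield X[idx], Y[idx]


# Toy data, made up for illustration.
X = torch.arange(10).unsqueeze(1)
Y = torch.arange(10)

shuffled = list(iterate_batches(X, Y, batch_size=4))
ordered = list(iterate_batches(X, Y, batch_size=4, shuffle=False))
```

This gives up worker processes and pinned-memory handling, which may not matter much when the data is already a tensor in memory, but that is worth measuring for your setup.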