How to Get the Batch's Sample Indices in the Dataset's "__getitem__(self, idx)" Instead of Only a Single Index?

Hi,

I’m currently building a custom data loader that can (i) change the batch_size value dynamically during training and (ii) process each data sample with a different operation depending on the batch_size.

I accomplished (i) by writing a custom batch_sampler that returns (a) batch_sample_indices with a varying number of indices (its len = batch_size). However, I still can't solve (ii): for that, I need to know (a), or at least len(a), inside the __getitem__() method of the dataset class. Can anybody point me in the right direction? Other approaches to achieving (i) and (ii) are also welcome.

Thanks in advance!

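You can implement both (i) and (ii) by letting the batch sampler yield (index, batch_size) tuples instead of plain indices. The DataLoader passes each element of a batch to the dataset's __getitem__ as-is, so the dataset receives the batch size together with the index: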

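```python
import random

from torch.utils.data import DataLoader, Dataset, RandomSampler, Sampler


class MyBatchSampler(Sampler):
    """Yields batches of (index, batch_size) tuples, drawing a new batch size for every batch."""

    def __init__(self, sampler, batch_size_list, drop_last):
        self.sampler = sampler
        self.batch_size_list = batch_size_list
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        batch_size = random.choice(self.batch_size_list)
        for idx in self.sampler:
            # Attach the chosen batch size to every index so __getitem__ can see it.
            batch.append((idx, batch_size))
            if len(batch) == batch_size:
                yield batch
                batch = []
                batch_size = random.choice(self.batch_size_list)
        # The last, incomplete batch still carries the batch_size drawn for it,
        # not its actual (smaller) length.
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        # Batch sizes are random, so the exact number of batches is not known in advance.
        # Return an upper bound based on the smallest possible batch size.
        min_batch_size = min(self.batch_size_list)
        if self.drop_last:
            return len(self.sampler) // min_batch_size
        return (len(self.sampler) + min_batch_size - 1) // min_batch_size


class MyDataSet(Dataset):
    def __getitem__(self, item):
        # Each item is the (index, batch_size) tuple produced by MyBatchSampler,
        # so you can branch on batch_size here to process the sample differently.
        index, batch_size = item
        return index, batch_size

    def __len__(self):
        return 10


my_dataset = MyDataSet()
my_sampler = MyBatchSampler(RandomSampler(my_dataset), [1, 2, 3], drop_last=False)
dataloader = DataLoader(my_dataset, batch_sampler=my_sampler)
for data in dataloader:
    print(data)
```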
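
If the per-batch-size operation can run after the samples are loaded, another option is the `collate_fn` argument of `DataLoader`: it receives the whole list of samples for a batch, so `len(samples)` is the actual batch size there, including for a smaller final batch. The sketch below assumes a batch sampler that yields plain index lists of varying length, as described in the question; `VaryingBatchSampler`, `per_size_collate` and the `SCALE_BY_SIZE` scaling are hypothetical placeholders for your own sampler and processing.

```python
import random

import torch
from torch.utils.data import DataLoader, Dataset, Sampler, SequentialSampler


class VaryingBatchSampler(Sampler):
    """Hypothetical sampler: yields lists of plain indices whose length is drawn at random."""

    def __init__(self, sampler, batch_size_list):
        self.sampler = sampler
        self.batch_size_list = batch_size_list

    def __iter__(self):
        batch = []
        batch_size = random.choice(self.batch_size_list)
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == batch_size:
                yield batch
                batch = []
                batch_size = random.choice(self.batch_size_list)
        if batch:  # keep the final, possibly smaller, batch
            yield batch
    # No __len__: the number of batches depends on the random sizes.


class RangeDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        # Only a single index arrives here.
        return torch.tensor(float(idx))


# Hypothetical per-size processing: scale each batch by a factor chosen from its size.
SCALE_BY_SIZE = {1: 1.0, 2: 10.0, 3: 100.0}


def per_size_collate(samples):
    """Stacks the samples, then applies an operation chosen by the real batch size."""
    batch_size = len(samples)
    batch = torch.stack(samples)  # shape: (batch_size,)
    return batch_size, batch * SCALE_BY_SIZE.get(batch_size, 1.0)


dataset = RangeDataset()
loader = DataLoader(
    dataset,
    batch_sampler=VaryingBatchSampler(SequentialSampler(dataset), [1, 2, 3]),
    collate_fn=per_size_collate,
)
for size, batch in loader:
    print(size, batch)
```

The trade-off compared with passing (index, batch_size) tuples is that the per-size branch runs once per batch rather than once per sample, and when `num_workers > 0` it runs inside the worker processes.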

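On recent PyTorch releases (2.x), a map-style dataset can also optionally define `__getitems__(self, indices)`. With automatic batching (a `batch_size` or `batch_sampler` is given), the DataLoader then calls it once with the full list of indices for the batch instead of calling `__getitem__` once per index, so the batch size is simply `len(indices)`. This is a sketch assuming that behavior; older releases ignore `__getitems__`. `VaryingBatchSampler`, `BatchAwareDataset` and the `SCALE_BY_SIZE` scaling are hypothetical.

```python
import random

import torch
from torch.utils.data import DataLoader, Dataset, Sampler, SequentialSampler


class VaryingBatchSampler(Sampler):
    """Hypothetical sampler: yields lists of plain indices whose length is drawn at random."""

    def __init__(self, sampler, batch_size_list):
        self.sampler = sampler
        self.batch_size_list = batch_size_list

    def __iter__(self):
        batch = []
        batch_size = random.choice(self.batch_size_list)
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == batch_size:
                yield batch
                batch = []
                batch_size = random.choice(self.batch_size_list)
        if batch:  # keep the final, possibly smaller, batch
            yield batch


# Hypothetical per-size processing: scale each sample by a factor chosen from the batch size.
SCALE_BY_SIZE = {1: 1.0, 2: 10.0, 3: 100.0}


class BatchAwareDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        # Still used when the dataset is indexed directly, e.g. dataset[3].
        return torch.tensor(float(idx))

    def __getitems__(self, indices):
        # Assumed PyTorch 2.x behavior: receives every index of the batch at once.
        scale = SCALE_BY_SIZE.get(len(indices), 1.0)
        # Return a list of samples; the default collate_fn stacks them into one tensor.
        return [torch.tensor(float(idx)) * scale for idx in indices]


dataset = BatchAwareDataset()
loader = DataLoader(
    dataset,
    batch_sampler=VaryingBatchSampler(SequentialSampler(dataset), [1, 2, 3]),
)
for batch in loader:
    print(batch)
```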

Thanks a lot! Works like a charm :wink: