How to use a BatchSampler with a dataset's __getitem__

I am trying to call the `__getitem__` function of my dataset only once per batch, because each dataset query is expensive (it goes to a remote source).

```python
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class MyDataset(Dataset):

    def __init__(self):
        ...

    def __len__(self):
        ...

    def __getitem__(self, batch_idx):  # <-- here I only get a single index
        return self.wiki_df.loc[batch_idx]


loader = DataLoader(
    dataset=dataset,
    batch_sampler=BatchSampler(
        SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False),
    num_workers=self.hparams.num_data_workers,
)
```

This is my current implementation, but it does not work: `__getitem__` still receives only one index at a time.
Is there a way to get the whole list of batch indices in the dataset's `__getitem__`?
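To see why this setup calls `__getitem__` once per sample, here is a minimal sketch with a toy dataset (`RecordingDataset` is a hypothetical name used only for illustration) that logs every index it receives. When a `BatchSampler` is passed as `batch_sampler`, automatic batching stays on, so the DataLoader loops over each yielded list and calls `__getitem__` with one int at a time before collating.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class RecordingDataset(Dataset):
    """Toy dataset (hypothetical) that logs the index argument of every __getitem__ call."""

    def __init__(self, n=8):
        self.data = torch.arange(n).float()
        self.calls = []  # what __getitem__ received, in call order

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        self.calls.append(index)
        return self.data[index]


dataset = RecordingDataset()
loader = DataLoader(
    dataset,
    batch_sampler=BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False),
    num_workers=0,  # stay in the main process so dataset.calls is visible afterwards
)
batches = [batch for batch in loader]

print(dataset.calls)                       # [0, 1, 2, 3, 4, 5, 6, 7]
print([tuple(b.shape) for b in batches])  # [(4,), (4,)]
```

Eight separate calls for two batches: the batch is assembled by the default collate function, not by the dataset.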


You could disable automatic batching (as described in the `DataLoader` docs) and pass a `BatchSampler` as the sampler.
Let me know if that works for you.
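As a rough mental model of what "disabling automatic batching" changes, the sketch below is a simplified pure-Python paraphrase of how a map-style DataLoader turns one sampler output into data. It is not PyTorch's actual code: the real loader also handles workers, pinning and, when batching is off, uses a default `collate_fn` that only converts data to tensors instead of stacking. `fetch` and `CountingDataset` are hypothetical names.

```python
from typing import Any, Callable, List, Union


def fetch(dataset: Any,
          index: Union[int, List[int]],
          auto_collation: bool,
          collate_fn: Callable[[Any], Any]) -> Any:
    """Simplified model of one map-style DataLoader fetch (not PyTorch's real code)."""
    if auto_collation:
        # batch_size / batch_sampler set: `index` is a list and
        # __getitem__ is called once per element before collating.
        samples = [dataset[i] for i in index]
    else:
        # batch_size=None: whatever the sampler yielded (here a whole list)
        # is handed to a single __getitem__ call.
        samples = dataset[index]
    return collate_fn(samples)


class CountingDataset:
    """Toy dataset (hypothetical) that counts how often __getitem__ is called."""

    def __init__(self):
        self.num_calls = 0

    def __getitem__(self, index):
        self.num_calls += 1
        if isinstance(index, list):
            return [i * 10 for i in index]
        return index * 10


per_index = CountingDataset()
per_index_out = fetch(per_index, [0, 1, 2], auto_collation=True, collate_fn=list)
print(per_index_out, per_index.num_calls)  # [0, 10, 20] 3

per_batch = CountingDataset()
per_batch_out = fetch(per_batch, [0, 1, 2], auto_collation=False, collate_fn=list)
print(per_batch_out, per_batch.num_calls)  # [0, 10, 20] 1
```

So with batching disabled, the batch size comes from whatever the sampler yields, and a `BatchSampler` used as `sampler` yields whole lists of indices.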


Conceptually, yes, but in practice I just can't get my head around the documentation.
If I set both `batch_sampler` and `batch_size` to `None` (to turn off automatic batching), how does the DataLoader know my batch size? And how does `__getitem__` get triggered?

Maybe this code snippet helps:

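```python
import torch
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100).view(100, 1).float()

    def __getitem__(self, index):
        # index is a list of 10 indices here, so this returns the whole batch
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)


dataset = MyDataset()
sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(dataset),
    batch_size=10,
    drop_last=False)

loader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=None)  # disable automatic batching; the sampler already yields batches

for data in loader:
    print(data)
```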

The `index` passed to `__getitem__` will contain 10 random indices, which are used to create the whole batch in a single call:

```python
x = self.data[index]
```

You could replace this with your expensive data loading operation.
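One detail worth checking: `DataLoader` defaults to `batch_size=1`, so if `batch_size` is left unset, automatic batching stays on and each list of 10 indices is treated as a single "sample" that gets wrapped in a batch of size 1. A minimal sketch comparing both settings, reusing the toy dataset from the snippet above:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100).view(100, 1).float()

    def __getitem__(self, index):
        # index is a list of 10 ints, because the sampler yields whole batches
        return self.data[index]

    def __len__(self):
        return len(self.data)


dataset = MyDataset()
sampler = BatchSampler(RandomSampler(dataset), batch_size=10, drop_last=False)

# Default batch_size=1: automatic batching is still on, so each sampled list
# is collated into an extra leading batch dimension of size 1.
wrapped = next(iter(DataLoader(dataset, sampler=sampler)))

# batch_size=None: automatic batching is off; __getitem__'s output is used as-is.
unwrapped = next(iter(DataLoader(dataset, sampler=sampler, batch_size=None)))

print(wrapped.shape)    # torch.Size([1, 10, 1])
print(unwrapped.shape)  # torch.Size([10, 1])
```

Both variants call `__getitem__` once per batch; `batch_size=None` just avoids the extra dimension.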

Let me know, if that helps.


Yes, it does :)
It's not very intuitive that the `BatchSampler` goes into the `sampler` argument, though.
Maybe it's just me.


Can we do the same thing with the `batch_sampler` argument?

Could you explain your question and use case a bit?
In my code snippet I’m already using a BatchSampler.
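If the goal is to keep `batch_sampler=...` (automatic batching on) and still query once per batch, recent PyTorch releases (2.x; worth checking against your installed version) let a map-style dataset define `__getitems__`, which the DataLoader calls once with the whole list of batch indices; the returned list of samples is then passed to `collate_fn`. A minimal sketch, where `BatchedQueryDataset` and `query_sizes` are hypothetical names and the tensor lookup stands in for the expensive remote query:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class BatchedQueryDataset(Dataset):
    """Hypothetical dataset that serves a whole batch from one (simulated) query."""

    def __init__(self):
        self.data = torch.arange(100).view(100, 1).float()
        self.query_sizes = []  # number of indices each batched query received

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Single-index access, used when __getitems__ is not picked up
        return self.data[index]

    def __getitems__(self, indices):
        # One "expensive" lookup for the whole batch
        self.query_sizes.append(len(indices))
        rows = self.data[indices]
        # Return a list of samples; the default collate_fn stacks them again
        return list(rows.unbind(0))


dataset = BatchedQueryDataset()
loader = DataLoader(
    dataset,
    batch_sampler=BatchSampler(SequentialSampler(dataset), batch_size=10, drop_last=False),
)
batches = list(loader)

print(dataset.query_sizes)  # one entry of 10 per batch when __getitems__ is supported
print(batches[0].shape)     # torch.Size([10, 1])
```

On older versions that ignore `__getitems__`, the loader silently falls back to per-index `__getitem__` calls, so the `sampler=BatchSampler(...), batch_size=None` approach above is the version-independent option.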