I am trying to call the `__getitem__` function of my dataset once per batch, because each dataset query (on a remote store) is expensive.
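```python
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class WikiDataset(Dataset):  # renamed so it no longer shadows torch's Dataset
    def __init__(self):
        ...

    def __len__(self):
        ...

    def __getitem__(self, batch_idx):
        # Problem: batch_idx is a single index here, not the list of batch indices
        return self.wiki_df.loc[batch_idx]


# (inside a method of the class that owns self.hparams; dataset = WikiDataset())
loader = DataLoader(
    dataset=dataset,
    batch_sampler=BatchSampler(
        SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False
    ),
    num_workers=self.hparams.num_data_workers,
)
```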
This is my current implementation, and it does not work: `__getitem__` still receives one index at a time.
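As far as I understand the map-style `DataLoader`, a `batch_sampler` only controls how indices are *grouped*. With automatic batching on, the loader still calls `dataset[i]` once per index in each group and then collates the results. With automatic batching off, it calls `dataset[x]` once per element `x` the sampler yields. The pure-Python model below illustrates both modes. It is a simplified sketch, not PyTorch's source, and every helper name in it is hypothetical.

```python
from typing import Callable, Iterable, Iterator, List, Sequence


def batch_indices(indices: Sequence[int], batch_size: int) -> Iterator[List[int]]:
    """Hypothetical helper: groups indices like BatchSampler(..., drop_last=False)."""
    batch: List[int] = []
    for idx in indices:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def fetch_with_auto_batching(getitem: Callable, sampler: Iterable[List[int]]) -> Iterator[list]:
    """Models batch_size/batch_sampler being set: one __getitem__ call per index.

    The real DataLoader then passes each list through collate_fn (omitted here).
    """
    for indices in sampler:
        yield [getitem(i) for i in indices]


def fetch_without_auto_batching(getitem: Callable, sampler: Iterable) -> Iterator:
    """Models batch_size=None, batch_sampler=None: one __getitem__ call per sampler element."""
    for key in sampler:
        yield getitem(key)


calls = []


def getitem(key):
    calls.append(key)  # record exactly what __getitem__ receives
    return key


list(fetch_with_auto_batching(getitem, batch_indices(range(5), 2)))
print(calls)  # [0, 1, 2, 3, 4] -> five "remote" queries

calls.clear()
list(fetch_without_auto_batching(getitem, batch_indices(range(5), 2)))
print(calls)  # [[0, 1], [2, 3], [4]] -> three queries, one per batch
```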
Is there a way to receive the whole list of batch indices in the dataset's `__getitem__`?
Well, conceptually yes, but in practice I just can't get my head around the documentation.
If I set both `batch_sampler` and `batch_size` to `None` (to turn off automatic batching), how does the system know my batch size? And how does `__getitem__` get triggered?
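A minimal sketch of one way to get this behavior, based on my reading of the `torch.utils.data` docs: pass the `BatchSampler` as the `sampler` and set `batch_size=None`. With automatic batching disabled, `DataLoader` calls `dataset[x]` once for each element `x` the sampler yields. Because a `BatchSampler` yields lists of indices, `__getitem__` receives a whole list. The batch size therefore lives entirely in the `BatchSampler`. `BatchQueryDataset` and its in-memory `rows` are hypothetical stand-ins for the remote table.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class BatchQueryDataset(Dataset):
    """Hypothetical dataset whose __getitem__ receives a list of indices."""

    def __init__(self, rows):
        self.rows = rows      # stand-in for the remote table (assumption)
        self.num_queries = 0  # counts simulated "remote" calls

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, batch_indices):
        # With automatic batching off, batch_indices is the list yielded
        # by the BatchSampler, e.g. [0, 1, 2, 3]: one query per batch.
        self.num_queries += 1
        return torch.tensor([self.rows[i] for i in batch_indices])


dataset = BatchQueryDataset(list(range(10)))
loader = DataLoader(
    dataset,
    # The BatchSampler is the *sampler*, so each element it yields is a list.
    sampler=BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False),
    batch_size=None,  # disables automatic batching (batch_sampler stays None)
    num_workers=0,    # main process, so num_queries is visible here
)

batches = [batch.tolist() for batch in loader]
print(batches)              # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(dataset.num_queries)  # 3
```

With `batch_size=None`, the default `collate_fn` only converts data (for example NumPy arrays to tensors) and adds no batch dimension, so `__getitem__` should return already-batched data. In the original dataset, `self.wiki_df.loc[batch_idx]` already accepts a list of labels, but it returns a DataFrame that would still need converting to tensors. With `num_workers > 0`, each worker process receives whole index lists in the same way.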