I need to use a BatchSampler exactly as described in the PyTorch documentation, but I cannot work out how to use a BatchSampler with a given dataset.
For example:
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler

class MyDataset(Dataset):
    def __init__(self, remote_ddf):
        self.ddf = remote_ddf
    def __len__(self):
        return len(self.ddf)
    def __getitem__(self, idx):
        # Fetching a single item is as expensive as fetching a whole batch.
        return self.ddf[idx]
    def get_batch(self, batch_idx):
        return self.ddf[batch_idx]

dataset = MyDataset(remote_ddf)
my_loader = DataLoader(
    dataset,
    batch_sampler=BatchSampler(SequentialSampler(dataset), batch_size=3, drop_last=False),
)
What I do not understand, and could not find in any example online or in the torch docs, is how to make the DataLoader call my get_batch function instead of __getitem__.
Ideally, I would not implement a per-item __getitem__ at all, and would let my custom dataset build the whole batch itself.
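For reference, here is a minimal sketch of one approach that seems to fit, as far as I can tell from the DataLoader documentation on disabling automatic batching. If the BatchSampler is passed as `sampler` (not `batch_sampler`) and `batch_size=None`, each list of indices should be handed to `__getitem__` in a single call. The `BatchedDataset` class and the tensor standing in for `remote_ddf` are assumptions for illustration only; a real remote dataframe would need to support indexing by a list of positions.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class BatchedDataset(Dataset):
    """Hypothetical dataset whose __getitem__ receives a whole list of indices."""

    def __init__(self, data):
        # `data` is a stand-in for remote_ddf; a tensor supports list indexing.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, batch_idx):
        # batch_idx is a list of ints produced by the BatchSampler,
        # so a single call fetches the whole batch.
        return self.data[batch_idx]


data = torch.arange(10, dtype=torch.float32)  # stand-in for remote_ddf
dataset = BatchedDataset(data)

loader = DataLoader(
    dataset,
    # Pass the BatchSampler as `sampler`, not `batch_sampler`.
    sampler=BatchSampler(SequentialSampler(dataset), batch_size=3, drop_last=False),
    batch_size=None,  # disable automatic batching and collation
)

batches = [batch for batch in loader]
```

In this sketch the body of get_batch simply becomes `__getitem__`, so no per-item access is needed. Whether this carries over to a remote dataframe depends on how it handles list indexing, which I have not verified.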