I would like to access the batches created by DataLoader with their indices. Is there an easy function in PyTorch for this?
More precisely, I’d like to say something like:
val_data = torchvision.datasets.ImageFolder(root='./imagenet2012', transform=transform)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batchSize)
for i in range(len(val_loader)):
    inputs, _ = val_loader.__getbatch__(i)  # __getbatch__ is the wished-for method, not a real API
Any comment is much appreciated.
There isn’t a way to do this directly.
However, if you modify this file slightly, it’ll be possible: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py#L176-L201
In particular, see the next function of DataLoaderIter: you could enumerate self.sample_iter fully beforehand (it yields the indices of each mini-batch), and then add a function on that iterator that just returns a particular batch by its index.
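For reference, here is a minimal sketch of that idea which avoids editing the DataLoader source: the loader's public batch_sampler attribute yields the index list of each mini-batch, so you can materialize those lists up front and rebuild any batch on demand. The get_batch helper below is my own, not a PyTorch API, and it assumes a map-style dataset with shuffle=False:

import torch
import torchvision
import torchvision.transforms as transforms

# Resize so the images stack into a single batch tensor.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
val_data = torchvision.datasets.ImageFolder(root='./imagenet2012', transform=transform)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=32)

# Materialize the per-batch index lists once, up front.
batch_indices = list(val_loader.batch_sampler)

def get_batch(b):
    # Hypothetical helper: rebuild mini-batch b from its stored indices.
    samples = [val_data[i] for i in batch_indices[b]]
    inputs = torch.stack([s[0] for s in samples])
    targets = torch.tensor([s[1] for s in samples])
    return inputs, targets

inputs, _ = get_batch(0)  # first mini-batch; batch_indices[0] holds its dataset indices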
@hatef Have you successfully implemented that, and would you mind sharing your code, please? I'm still having difficulty despite @smth's hint.
I’d like to use num_workers > 1 and still be able to receive the indices. I’m not so familiar with multiprocessing. Is it sufficient to return the indices from _put_indices()? https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py#L290
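A workaround that sidesteps the multiprocessing machinery entirely: wrap the dataset so each sample also returns its own index; the default collate function then batches the indices alongside the data, and this works for any num_workers. A minimal sketch (IndexedDataset is my own wrapper, not a PyTorch class):

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset

class IndexedDataset(Dataset):
    # Wraps any map-style dataset so __getitem__ also returns the sample index.
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        data, target = self.dataset[idx]
        return data, target, idx

    def __len__(self):
        return len(self.dataset)

# Toy dataset standing in for val_data from the question above.
base = TensorDataset(torch.randn(100, 3, 8, 8), torch.randint(0, 10, (100,)))
loader = DataLoader(IndexedDataset(base), batch_size=16, num_workers=2)

for inputs, targets, indices in loader:
    # indices is a LongTensor of the dataset positions in this mini-batch,
    # no matter which worker produced it.
    pass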
You can use the tqdm_notebook package to do that:

from tqdm import tqdm_notebook

for i, (train_inputs, train_labels) in tqdm_notebook(enumerate(dl), total=len(dl)):
    # i is the running mini-batch index; dl is your DataLoader
    train_inputs, train_labels = to_var(train_inputs), to_var(train_labels)  # to_var: your helper that wraps tensors in Variables