I would like to access the batches created by a DataLoader by their index. Is there an easy function in PyTorch for this?
More precisely, I’d like to say something like:
val_data = torchvision.datasets.ImageFolder(root='./imagenet2012', transform=transform)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batchSize)
for i in range(len(val_loader)):
    inputs, _ = val_loader.__getbatch__(i)  # hypothetical: DataLoader has no __getbatch__ method
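With only the stock DataLoader API, a straightforward (if slow) workaround is to iterate the loader and skip the first i batches. The sketch below assumes a non-shuffling loader; nth_batch is a hypothetical helper name, and a TensorDataset stands in for the ImageFolder dataset. Every call re-loads batches 0..i, so random access costs O(i) batch loads.

```python
from itertools import islice

import torch
from torch.utils.data import DataLoader, TensorDataset


def nth_batch(loader, i):
    """Hypothetical helper: return batch i by iterating the loader and
    discarding the first i batches. With shuffle=True the result differs
    between calls, because each iter(loader) draws a new permutation."""
    if not 0 <= i < len(loader):
        raise IndexError(f"batch index {i} out of range")
    return next(islice(iter(loader), i, None))


# Usage: 10 samples in batches of 4 -> batch 1 holds samples 4..7.
data = TensorDataset(torch.arange(10))
loader = DataLoader(data, batch_size=4)
(batch,) = nth_batch(loader, 1)
```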
In particular, look at the next function of DataLoaderIter: you could fully enumerate self.sample_iter beforehand (it yields the indices of each mini-batch), and then write a function on that iterator that returns the mini-batch for a particular index.
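A minimal sketch of that idea, assuming a map-style dataset and a loader built with a batch_size (so loader.batch_sampler yields one list of dataset indices per mini-batch). IndexableLoader and get_batch are hypothetical names, not part of PyTorch. The wrapper enumerates the batch indices once, then builds batch i on demand with the loader's own collate_fn. It reads the dataset in the main process, so num_workers is not used.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class IndexableLoader:
    """Hypothetical wrapper giving random access to a DataLoader's batches."""

    def __init__(self, loader):
        self.loader = loader
        # Enumerate the batch sampler fully beforehand: one list of dataset
        # indices per mini-batch. With shuffle=True this freezes one epoch's
        # order, which will not match a later `for batch in loader` loop.
        self.batch_indices = list(loader.batch_sampler)

    def __len__(self):
        return len(self.batch_indices)

    def get_batch(self, i):
        # Fetch the samples of mini-batch i and collate them as the loader would.
        samples = [self.loader.dataset[j] for j in self.batch_indices[i]]
        return self.loader.collate_fn(samples)


# Usage: 10 samples in batches of 4 -> 3 batches, the last holding samples 8 and 9.
data = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))
loader = DataLoader(data, batch_size=4)
indexed = IndexableLoader(loader)
inputs, labels = indexed.get_batch(2)
```

Precomputing the index lists keeps the wrapper cheap: only the samples of the requested batch are loaded, unlike skipping through the iterator.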