One way to do this is to implement a subclass of torch.utils.data.Dataset that returns a triple (data, target, index) from its __getitem__ method. Your training loop can then unpack the index along with each batch, as in the sketch below.
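A minimal sketch of that loop, assuming the dataset is wrapped in a DataLoader named train_loader:

for data, target, index in train_loader:
    # `index` holds the positions of these samples in the underlying dataset,
    # so each prediction can be matched back to its original image
    print(index)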
Here’s an example class for the STL10 data set. I’d be happy to add the indexes to the datasets and submit a pull request if that would be useful.
import os
import numpy as np
from torch.utils.data import Dataset


class STL10DataSet(Dataset):
    """
    STL10 data set. Implemented to be able to get the index of the images
    in a mini-batch so that predictions can be associated with their original
    image data.
    https://github.com/mttk/STL10/blob/master/stl10_input.py
    """
    def __init__(self, root_dir, split='train', transform=None):
        """
        Args:
            root_dir (string): path where stl10_binary.tar.gz has been extracted.
            transform (callable, optional): Optional transform to be applied on an
                image.
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        with open(os.path.join(self.root_dir, '{}_y.bin'.format(self.split)), 'rb') as f:
            self.labels = np.fromfile(f, dtype=np.uint8) - 1
        with open(os.path.join(self.root_dir, 'class_names.txt')) as f:
            self.classes = f.read().splitlines()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        with open(os.path.join(self.root_dir, '{}_X.bin'.format(self.split)), 'rb') as f:
            c, h, w = 3, 96, 96
            size = c * h * w
            f.seek(idx * size)
            # read the bytes of the idx-th image from the file
            image = np.fromfile(f, dtype=np.uint8, count=size)
            # reshape to C x H x W
            image = np.reshape(image, (c, h, w))
            # transpose to H x W x C for the ToPILImage transform
            image = np.transpose(image, (2, 1, 0))
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx], idx
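A short usage sketch (the root_dir path here is just a placeholder for wherever the STL10 binaries were extracted):

from torch.utils.data import DataLoader
from torchvision import transforms

dataset = STL10DataSet('./stl10_binary', split='train', transform=transforms.ToTensor())
loader = DataLoader(dataset, batch_size=10, shuffle=True)
images, labels, indices = next(iter(loader))
print(indices)  # positions of these images in the original dataset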
That’s right! You can remove the transform=transforms.ToTensor() argument when creating the dataset if you would like to apply the PIL transformations in __getitem__.
Or just add them in the constructor with transforms.Compose, as sketched below.
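For example, a minimal sketch (the particular transforms chosen here are only an illustration):

from torchvision import transforms

transform = transforms.Compose([
    transforms.ToPILImage(),           # H x W x C uint8 array -> PIL image
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),             # PIL image -> C x H x W float tensor in [0, 1]
])
dataset = STL10DataSet('./stl10_binary', split='train', transform=transform)  # path is a placeholder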
In case the solutions already posted are not enough for anyone else:
In torch/utils/data/dataloader.py, in the function “_put_indices”, add this line at the end of the function: return indices
In the same file, in the function right below “_put_indices”, called “_process_next_batch”, modify the line: self._put_indices()
to be: indices = self._put_indices() # indices contains the indices in the batch.
To have the indices available while iterating over the dataloader, modify the line in the same file, same function “_process_next_batch”, from: return batch
to: return (batch, indices)
You can now access the indices in the training script with:

for data in dataloaders['train']:
    (input, target), indices = data
The code in dataloader.py should now look like this:
def _put_indices(self):
    assert self.batches_outstanding < 2 * self.num_workers
    indices = next(self.sample_iter, None)
    if indices is None:
        return
    self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
    self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
    self.batches_outstanding += 1
    self.send_idx += 1
    return indices  # Added this line

def _process_next_batch(self, batch):
    self.rcvd_idx += 1
    # self._put_indices()  # Old line
    indices = self._put_indices()  # Modified line
    if isinstance(batch, ExceptionWrapper):
        raise batch.exc_type(batch.exc_msg)
    # return batch  # Old line
    return (batch, indices)  # Modified line
To know the correspondence between indices and images you can simply add a line in your training script. I have my datasets defined in the standard way as:

image_datasets = {x: datasets.ImageFolder(os.path.join(args.data_dir, x), data_transforms[x]) for x in ['train', 'val']}

So, to get the indices of all the images in the “train” dataset, use:

image_datasets['train'].samples

The output is a list containing the paths of the images and their class ids. The order is alphabetical and is obtained by first sorting the folders of the dataset (classes) and then the files inside each folder (images).
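Each returned index can then be mapped back to its file, for example (a sketch assuming indices comes from a batch of the patched loader above):

idx = int(indices[0])
path, class_id = image_datasets['train'].samples[idx]  # (file path, class id) tuple
print(path, image_datasets['train'].classes[class_id])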
Note: to find out where the files of a module you are currently using are located, you can type something like the following.
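For instance (one common way; a sketch):

import torch.utils.data.dataloader
print(torch.utils.data.dataloader.__file__)  # path to the dataloader.py that is actually in use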
Hi @simo23, thank you for your proposed solution, but it does not seem to work properly. If the images are loaded in a shuffled way, the indices do not correspond to the correct original images. I tried not shuffling the images in my train loader and looked at the indices of the first batch. They are in order but start from 80, when they should start from 0 (e.g. if my batch size is 10, the obtained indices are 80, 81, … 89). This problem also seems related to the number of workers used in the train loader: with num_workers = 4 the indices start from 80, with num_workers = 1 they start from 20. Has anyone else run into the same problem?
If anyone is still looking for an easy solution, this is what I did. This function takes a Dataset class (not an instance) and returns a new class whose __getitem__ returns a (data, target, index) tuple:
def dataset_with_indices(cls):
    """
    Modifies the given Dataset class to return a tuple data, target, index
    instead of just data, target.
    """
    def __getitem__(self, index):
        data, target = cls.__getitem__(self, index)
        return data, target, index

    return type(cls.__name__, (cls,), {
        '__getitem__': __getitem__,
    })
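For example (using torchvision’s MNIST purely as an illustration; any Dataset whose __getitem__ returns a (data, target) pair should work):

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

MNISTWithIndices = dataset_with_indices(datasets.MNIST)
dataset = MNISTWithIndices('./data', train=True, download=True, transform=transforms.ToTensor())
loader = DataLoader(dataset, batch_size=32, shuffle=True)
data, target, index = next(iter(loader))
print(index)  # positions of these samples in the original dataset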
To get all indices, you could just use indices = range(len(dataset)).
If you want to get the current indices for the batch, my code example should work.
Do you run into any issues with this code snippet?