How to retrieve the sample indices of a mini-batch

Hi,

In case the solutions already posted are not enough for anyone else, here is what I did:

  1. In torch/utils/data/dataloader.py, add this line at the end of the function “_put_indices”:
    return indices

  2. In the same file, in “_process_next_batch” (the function right below “_put_indices”), change the line:
    self._put_indices()
    to:
    indices = self._put_indices() # indices contains the indices in the batch.

  3. To make the indices available while iterating over the dataloader, change this line in the same function, “_process_next_batch”, from:
    return batch
    to:
    return (batch, indices)

  4. You can now get the indices in the training script with:

     for data in dataloaders['train']:
         (input, target), indices = data

After these changes, the two functions in “dataloader.py” should look like this:

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1
        return indices # Added this line

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        # self._put_indices() # Old line
        indices = self._put_indices() # Modified line
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        #return batch # Old line
        return (batch, indices) # Modified line
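One caveat, as far as I can tell from the 0.4-era source: “_put_indices” pulls the *next* batch of indices from the sampler and queues it for a worker, so the `indices` returned by “_process_next_batch” belong to a batch that is being prefetched (roughly `2 * num_workers` batches ahead), not to the `batch` returned alongside them. Also, with `num_workers=0`, `__next__` returns the batch without calling “_process_next_batch” at all. An approach that avoids both problems, and does not require editing the installed library, is to wrap the dataset so each sample carries its own index. The class below is a minimal sketch; `IndexedDataset` is a hypothetical name, not part of PyTorch.

```python
from typing import Any, Sequence, Tuple


class IndexedDataset:
    """Hypothetical wrapper that returns (sample, index) instead of sample.

    Works with any map-style dataset (anything with __len__ and
    __getitem__), e.g. a torchvision ImageFolder.
    """

    def __init__(self, dataset: Sequence[Any]) -> None:
        self.dataset = dataset

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        # The sampler chose `index`; return it next to the sample so it stays
        # attached to that sample through shuffling and batching.
        return self.dataset[index], index


# Usage (assumes torch and the image_datasets dict defined further down):
#   loader = torch.utils.data.DataLoader(
#       IndexedDataset(image_datasets['train']), batch_size=4, shuffle=True)
#   for (input, target), indices in loader:
#       ...  # indices holds the dataset index of each sample in the batch
```

The default collate function turns the per-sample integer indices into a tensor, so `(input, target), indices = data` unpacks the same way as in step 4, but the indices now match the batch they arrive with.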
  5. To find out which image each index corresponds to, you only need one more line in your training script. My datasets are defined in the standard way:

    image_datasets = {x: datasets.ImageFolder(os.path.join(args.data_dir, x), data_transforms[x]) for x in ['train', 'val']}

So, to see which image every index of the “train” dataset refers to, use:

    image_datasets['train'].samples

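The output is a list of (image path, class index) pairs, and the position of each pair in that list is the index the DataLoader uses for that image. The list is sorted: first the class folders of the dataset are sorted by name, then the image files inside each folder. The sort uses Python's default string ordering, so uppercase names come before lowercase ones.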
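To make the ordering concrete, here is a simplified re-implementation (not torchvision's actual code) of how the `samples` list is built, plus a hypothetical helper `paths_for_batch` that maps a batch of indices back to (path, class index) pairs. It assumes a single level of class folders and treats every file as an image; the real `ImageFolder` also filters by file extension and descends into nested sub-folders.

```python
import os
from typing import Any, Dict, Iterable, List, Tuple


def list_samples(root: str) -> Tuple[List[Tuple[str, int]], Dict[str, int]]:
    """Simplified sketch of how ImageFolder orders its `samples` list."""
    # 1. Class folders are sorted by name; their position becomes the class index.
    classes = sorted(d for d in os.listdir(root)
                     if os.path.isdir(os.path.join(root, d)))
    class_to_idx = {name: i for i, name in enumerate(classes)}

    samples = []
    for name in classes:
        folder = os.path.join(root, name)
        # 2. Files inside each class folder are sorted by name as well.
        for fname in sorted(os.listdir(folder)):
            path = os.path.join(folder, fname)
            if os.path.isfile(path):
                samples.append((path, class_to_idx[name]))
    return samples, class_to_idx


def paths_for_batch(samples: List[Tuple[str, int]],
                    indices: Iterable[Any]) -> List[Tuple[str, int]]:
    """Hypothetical helper: look up the (path, class index) of each batch index."""
    # int() lets `indices` be a list of ints or a 1-D tensor of indices.
    return [samples[int(i)] for i in indices]
```

With a real dataset the lookup would be `paths_for_batch(image_datasets['train'].samples, indices)`.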

Note: to find out where the source file of a module you are currently using is located, you can run:

    $ python
    >>> import torch.utils.data
    >>> torch.utils.data.dataloader.__file__
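The same lookup works for any importable module. A small sketch using the standard-library `importlib.import_module` (the helper name `module_file` is made up for this example):

```python
import importlib


def module_file(name: str) -> str:
    """Return the path of the file that defines the (possibly dotted) module `name`."""
    module = importlib.import_module(name)  # also imports the parent packages
    return module.__file__


if __name__ == "__main__":
    # Prints the location of a standard-library module; with PyTorch installed,
    # module_file("torch.utils.data.dataloader") works the same way.
    print(module_file("json.decoder"))
```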