How to retrieve the sample indices of a mini-batch

Many thanks to @ndronen and @smth!!

Did you end up solving your problem? I would also like to retrieve the indices of the samples in a mini-batch.

The suggested solution should work. Is it not working in your case?

Yes, I got it working afterwards. It was a bug on my end. But yes, the suggested solution works. Thanks!

Here’s an example class for the STL10 data set. I’d be happy to add the indices to the datasets and submit a pull request if that would be useful.

import os
import numpy as np
from torch.utils.data import Dataset


class STL10DataSet(Dataset):
    """
    STL10 data set. Implemented to be able to get the index of the images
    in a mini-batch so that predictions can be associated with their original
    image data.

    https://github.com/mttk/STL10/blob/master/stl10_input.py
    """

    def __init__(self, root_dir, split='train', transform=None):
        """
        Args:
            root_dir (string): path where stl10_binary.tar.gz has been extracted.
            transform (callable, optional): Optional transform to be applied on an
                image.
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform

        # Labels are stored as 1-indexed bytes, so shift them to start at 0.
        with open(os.path.join(self.root_dir, '{}_y.bin'.format(self.split)), 'rb') as f:
            self.labels = np.fromfile(f, dtype=np.uint8) - 1

        with open(os.path.join(self.root_dir, 'class_names.txt')) as f:
            self.classes = f.read().splitlines()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        with open(os.path.join(self.root_dir, '{}_X.bin'.format(self.split)), 'rb') as f:
            c, h, w = 3, 96, 96
            size = c * h * w
            # seek directly to the idx-th image in the binary file
            f.seek(idx * size)
            # read bytes from file
            image = np.fromfile(f, dtype=np.uint8, count=size)
            # reshape to C x H x W
            image = np.reshape(image, (c, h, w))
            # transpose to H x W x C for ToPILImage transform
            image = np.transpose(image, (2, 1, 0))

        if self.transform:
            image = self.transform(image)

        # Return the index alongside the image and label so predictions can be
        # traced back to the original sample.
        return image, self.labels[idx], idx
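For illustration, a minimal sketch of using this class with a DataLoader to recover the dataset indices for each mini-batch; the extraction path, batch size, and transform are placeholders:

from torch.utils.data import DataLoader
from torchvision import transforms

# Hypothetical path; adjust to where stl10_binary was unpacked.
dataset = STL10DataSet(root_dir='./stl10_binary', split='train',
                       transform=transforms.Compose([transforms.ToPILImage(),
                                                     transforms.ToTensor()]))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for images, labels, indices in loader:
    # `indices` are the positions of these samples in the original dataset.
    print(indices[:5])
    break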

can you provide the import statement for Dataset?

Sorry, I’m confused. How does your code change if I wanted to use CIFAR10?

You could adapt this code:

from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.cifar10 = datasets.CIFAR10(root='YOUR_PATH',
                                        download=False,
                                        train=True,
                                        transform=transforms.ToTensor())
        
    def __getitem__(self, index):
        data, target = self.cifar10[index]
        
        # Your transformations here (or set it in CIFAR10)
        
        return data, target, index

    def __len__(self):
        return len(self.cifar10)

dataset = MyDataset()
loader = DataLoader(dataset,
                    batch_size=1,
                    shuffle=True,
                    num_workers=1)

for batch_idx, (data, target, idx) in enumerate(loader):
    print('Batch idx {}, dataset index {}'.format(
        batch_idx, idx))

I think it’s important to note that:

data, target = self.cifar10[index]

returns a FloatTensor and not a PIL Image, so one needs to cast it back to a PIL Image or a numpy array for the transforms to work, I believe. Thanks for the help :slight_smile:

That’s right! You can remove the transform=transforms.ToTensor() line when creating the dataset, if you would like to apply PIL transformations in __getitem__.
Or just add them in the constructor with transforms.Compose. :wink:
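For reference, a minimal sketch of the second option; the path is a placeholder and the augmentations are only examples:

from torchvision import datasets, transforms

# Compose PIL-based transforms and finish with ToTensor, so __getitem__
# already returns a FloatTensor.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

cifar10 = datasets.CIFAR10(root='YOUR_PATH',
                           download=False,
                           train=True,
                           transform=transform)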


what if I told u I had a bunch of neural nets creating the data set :rofl::joy:

wish I was joking

Hi,

in case the already posted solutions are not enough for anyone else:

  1. In torch/utils/data/dataloader.py, in the function “_put_indices”, add this line at the end of the function:
    return indices

  2. In the same file, in the function right below “_put_indices” called “_process_next_batch”, modify the line:
    self._put_indices()
    to be:
    indices = self._put_indices()  # indices contains the indices in the batch.

  3. To have the indices available while iterating the dataloader, modify the line in the same file, same function “_process_next_batch” from:
    return batch
    to:
    return (batch, indices)

  4. You can now know the indices in the training script with:

     for data in dataloaders['train']:
         (input, target), indices = data

The code in the DataLoader should now look like this:

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1
        return indices # Added this line

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        # self._put_indices() # Old line
        indices = self._put_indices() # Modified line
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        #return batch # Old line
        return (batch, indices) # Modified line
  5. To know the correspondence between indices and images, you can simply add a line to your training script. I have my datasets defined in the standard way as:

image_datasets = {x: datasets.ImageFolder(os.path.join(args.data_dir, x), data_transforms[x]) for x in ['train', 'val']}

So to know the index of all the images in the “train” dataset use:

image_datasets['train'].samples

The output is a list containing the paths of the images and their class ids. The order is alphabetical: it is obtained by first sorting the folders of the dataset (classes) and then the files inside each folder (images).
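For illustration, a small sketch (assuming the modified DataLoader above) of mapping the returned indices back to the original file paths:

for data in dataloaders['train']:
    (input, target), indices = data
    # samples[i] is a (path, class_index) tuple, so this recovers the file paths.
    paths = [image_datasets['train'].samples[i][0] for i in indices]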

Note: to find out where the files of a module you are currently using are located, you can type:

python
import torch.utils.data
torch.utils.data.dataloader.__file__

Hi @simo23, thank you for your proposed solution, but it doesn’t seem to work properly. If the images are shuffled, the returned indices do not correspond to the correct original images. I tried not shuffling the images in my train loader and looked at the indices of the first batch: they are in order, but they start from 80 when they should start from 0 :frowning: (e.g. if my batch size is 10, the obtained indices are 80, 81, …, 89). This problem also seems related to the number of workers used in the train loader: with num_workers = 4 the indices start from 80, and with num_workers = 1 they start from 20. Has anyone else run into this problem?

If anyone is still looking for an easy solution, this is what I did. This function takes a dataset class (not an instance) and returns a class that returns a data, target, index tuple from __getitem__:

def dataset_with_indices(cls):
    """
    Modifies the given Dataset class to return a tuple data, target, index
    instead of just data, target.
    """

    def __getitem__(self, index):
        data, target = cls.__getitem__(self, index)
        return data, target, index

    # Create a subclass of `cls` on the fly with the overridden __getitem__.
    return type(cls.__name__, (cls,), {
        '__getitem__': __getitem__,
    })

Then you can use it like so:

MNISTWithIndices = dataset_with_indices(MNIST)
dataset = MNISTWithIndices('~/datasets/mnist')

This at least avoids the need to write a child class for every dataset you want to do this with.
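As a quick follow-up sketch (the path and batch size are placeholders, and a ToTensor transform is added so the default collate can batch the images), the wrapped dataset works with a DataLoader exactly as before:

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

MNISTWithIndices = dataset_with_indices(MNIST)
dataset = MNISTWithIndices('~/datasets/mnist', download=True,
                           transform=transforms.ToTensor())
loader = DataLoader(dataset, batch_size=64, shuffle=True)

for data, target, idx in loader:
    print('dataset indices in this batch:', idx[:5])
    break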


Hi, I want to retrieve all the indices of my dataset. Could you give me some advice?

To get all indices, you could just create a range operation using indices = range(len(dataset)).
If you want to get the current indices for the batch, my code example should work.
Do you run into any issues with this code snippet?

I mean getting all the indices, in the same order as the torch.randperm(len(dataset)) used inside the DataLoader. I need all the indices to pass to another function. Please give me some ideas. Thanks.

If you need to handle the indices outside of the DataLoader, I would suggest sampling these indices manually, e.g. with torch.randperm, and using them in a Subset, or writing a custom sampler and passing it to the DataLoader.
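A minimal sketch of the first suggestion (the dataset and batch size are placeholders): draw the permutation manually so the same indices can also be passed elsewhere, then wrap it in a Subset:

import torch
from torch.utils.data import Subset, DataLoader

# `dataset` is any map-style Dataset; the permutation is created manually,
# so the same indices can be passed to another function as well.
indices = torch.randperm(len(dataset)).tolist()
subset = Subset(dataset, indices)

# shuffle=False so the DataLoader iterates in exactly the order of `indices`.
loader = DataLoader(subset, batch_size=32, shuffle=False)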

Hi,
I think it’s a good idea and I want to try it, but I don’t know how to write a dataset class. Can you give me some examples (e.g. CIFAR10)? Thanks very much :)

Here is my implementation.

from torch.utils.data import DataLoader, Dataset

class IndexDataset(Dataset):
    def __init__(self, dataset, subset=None):
        self.dataset = dataset
        self.subset = subset

    def __getitem__(self, index):
        if self.subset is None:
            data = self.dataset[index]
            real_index = index
        else:
            # Map the local index to the index in the underlying dataset.
            real_index = self.subset[index]
            data = self.dataset[real_index]

        if isinstance(data, dict):
            data["real_index"] = real_index
            return data
        elif isinstance(data, (list, tuple)):
            # Prepend the original index to the returned sample
            # (tuples are included, since most torchvision datasets return them).
            return [real_index] + list(data)
        else:
            raise NotImplementedError(f"Data type {type(data)} not supported")

    def __len__(self):
        if self.subset is not None:
            return len(self.subset)
        else:
            return len(self.dataset)
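For example, a sketch of wrapping CIFAR10 with this class (the root path is a placeholder). Since CIFAR10 returns (image, target) tuples, each sample becomes [real_index, image, target]:

from torchvision import datasets, transforms

cifar10 = datasets.CIFAR10(root='YOUR_PATH', train=True, download=False,
                           transform=transforms.ToTensor())

# Optionally restrict to a subset while keeping track of the original indices.
dataset = IndexDataset(cifar10, subset=list(range(1000)))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for real_index, data, target in loader:
    print('original dataset indices:', real_index[:5])
    break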