How to retrieve the sample indices of a mini-batch

If anyone is still looking for an easy solution, this is what I did. This function takes a dataset class (not an instance) and returns a subclass whose `__getitem__` returns a `(data, target, index)` tuple:

```python
def dataset_with_indices(cls):
    """
    Modifies the given Dataset class to return a tuple data, target, index
    instead of just data, target.
    """

    def __getitem__(self, index):
        # Call the parent class's __getitem__, then append the index.
        data, target = cls.__getitem__(self, index)
        return data, target, index

    # Build a subclass of cls (keeping its name) that overrides __getitem__.
    return type(cls.__name__, (cls,), {
        '__getitem__': __getitem__,
    })
```

Then you can use it like so:

```python
from torchvision.datasets import MNIST

MNISTWithIndices = dataset_with_indices(MNIST)
dataset = MNISTWithIndices('~/datasets/mnist')
```
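A minimal sketch of how the indices come out of a `DataLoader`, assuming a small `TensorDataset` stands in for MNIST (the toy tensors `xs` and `ys` are made up for illustration). The default collate function turns the per-sample integer indices into a `LongTensor`, so each mini-batch's indices can be used to index back into the original data:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def dataset_with_indices(cls):
    """Return a subclass of cls whose __getitem__ yields (data, target, index)."""

    def __getitem__(self, index):
        data, target = cls.__getitem__(self, index)
        return data, target, index

    return type(cls.__name__, (cls,), {"__getitem__": __getitem__})


# Toy data standing in for a real dataset (assumption for illustration).
xs = torch.arange(10, dtype=torch.float32).unsqueeze(1)  # shape (10, 1)
ys = torch.arange(10) % 2                                 # binary labels

TensorDatasetWithIndices = dataset_with_indices(TensorDataset)
dataset = TensorDatasetWithIndices(xs, ys)

# shuffle=True so batch order differs from dataset order; the seeded
# generator only makes the run reproducible.
loader = DataLoader(dataset, batch_size=4, shuffle=True,
                    generator=torch.Generator().manual_seed(0))

seen = []
for data, target, idx in loader:
    # idx is a LongTensor holding the dataset positions of this mini-batch.
    assert torch.equal(xs[idx], data)
    assert torch.equal(ys[idx], target)
    seen.extend(idx.tolist())

print(sorted(seen))
```

Because every sample is visited exactly once per epoch, the collected indices cover the whole dataset regardless of shuffling.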

This at least avoids the need to write a child class for every dataset you want to do this with.
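The wrapper does not depend on torch itself: it works on any class whose `__getitem__` returns a `(data, target)` pair. A quick check with a hypothetical toy class `SquaresDataset` (not part of any library) shows that the generated class is a real subclass and keeps the original name, so `isinstance` checks and `__len__` still behave as before:

```python
def dataset_with_indices(cls):
    """Return a subclass of cls whose __getitem__ yields (data, target, index)."""

    def __getitem__(self, index):
        data, target = cls.__getitem__(self, index)
        return data, target, index

    return type(cls.__name__, (cls,), {"__getitem__": __getitem__})


class SquaresDataset:
    """Hypothetical toy dataset: item i is (i * i, i % 2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        return index * index, index % 2


SquaresWithIndices = dataset_with_indices(SquaresDataset)
ds = SquaresWithIndices(5)
print(ds[3])  # (9, 1, 3)
```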
