If anyone is still looking for an easy solution, this is what I did. This function takes a dataset class (not an instance) and returns a subclass whose `__getitem__` returns a `(data, target, index)` tuple:
```python
def dataset_with_indices(cls):
    """
    Modifies the given Dataset class to return a tuple data, target, index
    instead of just data, target.
    """

    def __getitem__(self, index):
        # Call the parent class's __getitem__, then append the index
        data, target = cls.__getitem__(self, index)
        return data, target, index

    # Build a subclass of cls (same name) that overrides only __getitem__
    return type(cls.__name__, (cls,), {
        '__getitem__': __getitem__,
    })
```
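If the three-argument `type(name, bases, namespace)` call looks opaque, it is equivalent to an ordinary `class` statement. Here is a sketch of the same factory written that way (using `super()` instead of calling `cls.__getitem__` directly). The `Pairs` class is a hypothetical stand-in for a real dataset that returns `(data, target)`:

```python
def dataset_with_indices(cls):
    """Return a subclass of cls whose __getitem__ also yields the index."""

    class WithIndices(cls):
        def __getitem__(self, index):
            # Delegate to the parent class, then append the index
            data, target = super().__getitem__(index)
            return data, target, index

    # Keep the original class name so repr() and logging look the same
    WithIndices.__name__ = cls.__name__
    WithIndices.__qualname__ = cls.__qualname__
    return WithIndices


class Pairs:
    """Hypothetical dataset returning (data, target) pairs."""

    def __init__(self, n):
        self.items = [(i * 10, i % 2) for i in range(n)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]


PairsWithIndices = dataset_with_indices(Pairs)
ds = PairsWithIndices(3)
print(ds[2])  # (20, 0, 2)
```

Both versions produce a subclass, so the constructor, `__len__`, and any other methods of the original class are inherited unchanged.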
Then you can use it like so:
```python
from torchvision.datasets import MNIST

MNISTWithIndices = dataset_with_indices(MNIST)
dataset = MNISTWithIndices('~/datasets/mnist')
```
This at least avoids the need to write a child class for every dataset you want to do this with.
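To check the behaviour without downloading anything, here is a minimal sketch using a hypothetical `ToyDataset` in place of MNIST. It assumes, as the approach itself does, that the wrapped class's `__getitem__` returns exactly two values. The example collects the indices of all samples with target `0`, which is the kind of bookkeeping the extra index is useful for:

```python
def dataset_with_indices(cls):
    def __getitem__(self, index):
        data, target = cls.__getitem__(self, index)
        return data, target, index
    return type(cls.__name__, (cls,), {'__getitem__': __getitem__})


class ToyDataset:
    """Hypothetical stand-in for a (data, target) dataset such as MNIST."""

    def __init__(self, root, size=5):
        self.root = root
        self.samples = [(f"img{i}", i % 3) for i in range(size)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]


ToyWithIndices = dataset_with_indices(ToyDataset)
# Constructor arguments and __len__ come from ToyDataset unchanged
dataset = ToyWithIndices('~/datasets/toy', size=5)

# Each item is now (data, target, index)
hits = [idx for data, target, idx in (dataset[i] for i in range(len(dataset)))
        if target == 0]
print(hits)  # [0, 3]
```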