How does one obtain indices from a DataLoader?

I was looping through the data in a DataLoader to create a new dataset. I need to know which data points are which, so I wanted to get their indices (e.g. for CIFAR10). How does one do that?

Related:

Train on a fraction of the data set


The thread you linked already contains a valid solution:

Here is a small example using CIFAR10 from the other thread:


The thing I can’t generalize is how to do this for any DataLoader. The solutions seem to be specific to the corresponding datasets, or am I missing something?

The index is specific to a Dataset and you can return it in the __getitem__ function.
The DataLoader just calls its Dataset’s __getitem__ function with the indices produced by its sampler and collates the results into batches of the specified size.

I don’t think there is an easy way to modify a DataLoader to return the index. At least, I don’t know of one, sorry.
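One partial workaround, sketched here under the assumption that the DataLoader uses shuffle=False (and therefore the default sequential sampler): batches then come out in dataset order, so the indices of batch i are simply the range starting at i * batch_size. A small TensorDataset stands in for CIFAR10, with each feature equal to its own index so the result can be checked. This breaks as soon as shuffle=True or a custom sampler is used.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset: 10 samples whose single feature equals their own index.
features = torch.arange(10).float().unsqueeze(1)
labels = torch.arange(10) % 3
dataset = TensorDataset(features, labels)

batch_size = 4
# shuffle=False -> samples are visited in index order 0, 1, 2, ...
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

recovered, seen = [], []
for batch_idx, (x, y) in enumerate(loader):
    start = batch_idx * batch_size
    # Indices of the samples in this batch (the last batch may be shorter).
    indices = torch.arange(start, start + x.size(0))
    recovered.append(indices)
    # By construction the feature value equals the true index, so keep it for checking.
    seen.append(x.squeeze(1).long())

all_indices = torch.cat(recovered)
all_seen = torch.cat(seen)
```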

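# Method of a Dataset wrapper that stores a CIFAR10 instance as self.cifar10
def __getitem__(self, index):
    data, target = self.cifar10[index]
    return data, target, index  # also return the sample's index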
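A self-contained sketch of the wrapper that snippet belongs to, using a hypothetical class name CIFAR10WithIndex. To avoid a download, a small TensorDataset of random 3x32x32 tensors stands in for torchvision.datasets.CIFAR10; any dataset whose items are (data, target) pairs should work the same way.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class CIFAR10WithIndex(Dataset):
    """Hypothetical wrapper: yields (data, target, index) instead of (data, target)."""

    def __init__(self, cifar10):
        self.cifar10 = cifar10  # any dataset returning (data, target) pairs

    def __getitem__(self, index):
        data, target = self.cifar10[index]
        return data, target, index

    def __len__(self):
        return len(self.cifar10)


# Stand-in for CIFAR10: 8 fake images with labels in 0..9.
fake = TensorDataset(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))
loader = DataLoader(CIFAR10WithIndex(fake), batch_size=3, shuffle=True)

seen_indices = []
for data, target, index in loader:
    # The default collate_fn turns the Python int indices into a LongTensor.
    seen_indices.extend(index.tolist())
```

Because the index travels with the sample, it stays correct even with shuffle=True.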

Hi @Brando_Miranda, did you solve your problem already? When I saw your last question, what came to mind was to extract the data from CIFAR10 as a NumPy array, like this:

data = cifar10.data  # named train_data in older torchvision releases

Check out the source code here. Now you can work with this array, apply your rule to get the indices you want, and pass them to SubsetRandomSampler.
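A rough sketch of that idea, assuming a recent torchvision where the labels are exposed as cifar10.targets (a list of ints). Here plain Python lists and random tensors stand in for the real CIFAR10 attributes, and the rule "keep only classes 0 and 1" is just an example.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Stand-ins for cifar10.data and cifar10.targets: 20 fake images, labels 0..4.
targets = [i % 5 for i in range(20)]
images = torch.randn(20, 3, 32, 32)
dataset = TensorDataset(images, torch.tensor(targets))

# Example rule (an assumption): keep only samples whose label is 0 or 1.
keep = [i for i, t in enumerate(targets) if t in (0, 1)]

# SubsetRandomSampler yields the indices in `keep` in random order, without replacement.
sampler = SubsetRandomSampler(keep)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

drawn_labels = torch.cat([y for _, y in loader])
```

Note that sampler and shuffle=True are mutually exclusive in DataLoader; the sampler already does the shuffling.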


Thanks! I wish it had occurred to me to look into that…

See here for a general solution: How to retrieve the sample indices of a mini-batch
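One general pattern, sketched here as an assumption rather than a quote from that thread: create a subclass of any dataset class on the fly whose __getitem__ also returns the index. The helper name dataset_with_indices is hypothetical, and it assumes the base class returns a (data, target) pair; TensorDataset is used as the example base.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def dataset_with_indices(cls):
    """Hypothetical helper: subclass of `cls` whose __getitem__ yields (data, target, index).

    Assumes cls.__getitem__ returns a (data, target) pair.
    """

    def __getitem__(self, index):
        data, target = cls.__getitem__(self, index)
        return data, target, index

    # type(name, bases, namespace) builds the subclass dynamically.
    return type(cls.__name__, (cls,), {"__getitem__": __getitem__})


TensorDatasetWithIndices = dataset_with_indices(TensorDataset)
# The subclass keeps the base constructor, so it is built exactly like TensorDataset.
ds = TensorDatasetWithIndices(torch.arange(6).float(), torch.arange(6) * 10)
loader = DataLoader(ds, batch_size=4, shuffle=True)
batches = [(x, y, idx) for x, y, idx in loader]
```

The advantage over a wrapper is that the original constructor and attributes stay available unchanged.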


Here is my implementation.

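from torch.utils.data import DataLoader, Dataset


class IndexDataset(Dataset):
    """Wraps a dataset so every sample also carries its index in the original dataset."""

    def __init__(self, dataset, subset=None):
        self.dataset = dataset
        self.subset = subset  # optional list of indices into `dataset`

    def __getitem__(self, index):
        if self.subset is None:
            real_index = index
        else:
            # Map the position within the subset to the original dataset index.
            real_index = self.subset[index]
        data = self.dataset[real_index]

        if isinstance(data, dict):
            # Copy so the wrapped dataset's own dict is not modified.
            data = dict(data)
            data["real_index"] = real_index
            return data
        elif isinstance(data, (list, tuple)):
            # torchvision datasets return (image, target) tuples, so accept tuples too.
            return [real_index, *data]
        else:
            raise NotImplementedError(f"Data type {type(data)} not supported")

    def __len__(self):
        if self.subset is not None:
            return len(self.subset)
        return len(self.dataset)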
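For comparison, torch.utils.data.Subset already stores the same mapping that the subset branch of IndexDataset uses: subset.indices[i] is the index in the original dataset of the subset's i-th item. A small sketch, with a TensorDataset whose values encode their original index:

```python
import torch
from torch.utils.data import Subset, TensorDataset

base = TensorDataset(torch.arange(10) * 100)  # value = 100 * original index
subset = Subset(base, [7, 2, 5])

# Map each position in the subset back to its index in `base`.
real_indices = [subset.indices[i] for i in range(len(subset))]
values = [subset[i][0].item() for i in range(len(subset))]
```

So if the data is already wrapped in a Subset, the real index of the item at position i can be read from subset.indices without changing __getitem__.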