How to retrieve the sample indices of a mini-batch

I mean: how can I get all the indices, i.e. the same permutation that `torch.randperm(len(dataset))` produces inside the DataLoader? I need all of them to pass to another function. Any ideas? Thanks.

If you need to handle the indices outside of the DataLoader, I would suggest sampling these indices manually (e.g. with `torch.randperm`) and using them in a `Subset`, or writing a custom sampler and passing it to the DataLoader.
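Both suggestions can be sketched roughly as follows. This is a minimal illustration, not the only way to do it: the `TensorDataset` is a toy stand-in for a real dataset, and `RecordingRandomSampler` is a hypothetical name for a sampler that simply stores the permutation it hands to the DataLoader.

```python
import torch
from torch.utils.data import DataLoader, Sampler, Subset, TensorDataset

# Toy stand-in for a real dataset: 10 samples, target == sample index.
dataset = TensorDataset(torch.arange(30, dtype=torch.float32).view(10, 3), torch.arange(10))

# Option 1: draw the permutation yourself and wrap it in a Subset.
indices = torch.randperm(len(dataset)).tolist()   # pass this list anywhere you like
subset = Subset(dataset, indices)
loader = DataLoader(subset, batch_size=4, shuffle=False)  # shuffle=False keeps our order


# Option 2 (hypothetical helper): a sampler that remembers the permutation it yields.
class RecordingRandomSampler(Sampler):
    def __init__(self, data_source, generator=None):
        self.data_source = data_source
        self.generator = generator
        self.indices = []

    def __iter__(self):
        # A fresh permutation each time the DataLoader starts an epoch;
        # it is stored so other code can read it afterwards.
        self.indices = torch.randperm(len(self.data_source), generator=self.generator).tolist()
        return iter(self.indices)

    def __len__(self):
        return len(self.data_source)


sampler = RecordingRandomSampler(dataset)
loader2 = DataLoader(dataset, batch_size=4, sampler=sampler)

# Iterate one epoch; the batches follow sampler.indices in order.
recorded = [i for _, y in loader2 for i in y.tolist()]
```

The sampler always runs in the main process (even with `num_workers > 0`), so `sampler.indices` after an epoch is the order in which samples were actually requested.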

Hi,
I think it's a good idea and I want to try it, but I don't know how to write a dataset class. Can you give me an example (e.g. CIFAR10)? Thanks very much :)

Here is my implementation.

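```python
from torch.utils.data import DataLoader, Dataset


class IndexDataset(Dataset):
    """Wraps a dataset so every sample also carries its index in the wrapped dataset."""

    def __init__(self, dataset, subset=None):
        self.dataset = dataset
        self.subset = subset  # optional list of indices into `dataset`

    def __getitem__(self, index):
        if self.subset is None:
            data = self.dataset[index]
            real_index = index
        else:
            # Map the position inside the subset back to the original index.
            real_index = self.subset[index]
            data = self.dataset[real_index]

        if isinstance(data, dict):
            # Copy instead of mutating a dict the dataset might reuse.
            return {**data, "real_index": real_index}
        elif isinstance(data, (list, tuple)):
            # torchvision datasets such as CIFAR10 return (image, target) tuples.
            return (real_index, *data)
        else:
            raise NotImplementedError(f"Data type {type(data)} not supported")

    def __len__(self):
        if self.subset is not None:
            return len(self.subset)
        else:
            return len(self.dataset)
```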
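A usage sketch, assuming a wrapper like the one above: to keep it self-contained it redefines a compact, tuple-only variant under the hypothetical name `WithIndex`, and a random `TensorDataset` stands in for CIFAR10 so nothing needs to be downloaded. Each batch then arrives as `(indices, images, labels)`, where `indices` are positions in the original dataset.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class WithIndex(Dataset):
    """Compact, tuple-only variant of IndexDataset (hypothetical name)."""

    def __init__(self, dataset, subset=None):
        self.dataset = dataset
        self.subset = subset

    def __getitem__(self, index):
        real_index = index if self.subset is None else self.subset[index]
        return (real_index, *self.dataset[real_index])

    def __len__(self):
        return len(self.dataset) if self.subset is None else len(self.subset)


# Stand-in for CIFAR10: 8 fake 3x4x4 "images"; label = 10 * original index.
base = TensorDataset(torch.randn(8, 3, 4, 4), torch.arange(8) * 10)
# With torchvision installed, this could instead be e.g.
# base = torchvision.datasets.CIFAR10("data", train=True, download=True,
#                                     transform=torchvision.transforms.ToTensor())
wrapped = WithIndex(base, subset=[6, 2, 4])
loader = DataLoader(wrapped, batch_size=2, shuffle=True)

all_idx, all_labels = [], []
for idx, images, labels in loader:
    # default collation turns the plain int indices into a tensor
    all_idx += idx.tolist()
    all_labels += labels.tolist()
```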

I love the elegance and versatility of this answer. Unfortunately, it cannot be used in every multiprocessing setup (it depends on the particular configuration, I guess): I ended up with a PicklingError (Can't pickle <class …>: attribute lookup … failed), which I suspect is caused by creating the wrapper class on the fly.
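The guess about the on-the-fly class is plausible: pickle stores classes by reference (module plus name), so a class built at runtime with `type()` and never bound to a module-level name cannot be looked up again. DataLoader workers only need to pickle the dataset when they are started with the `spawn` method (the default on Windows and macOS), not with `fork`, which would explain why it depends on the setup. The sketch below uses plain Python with hypothetical classes (`Base`, `make_wrapper`, `StaticWithIndex`) to contrast the two cases when run as a script.

```python
import pickle


class Base:
    """Stand-in for a dataset class (hypothetical)."""

    def __getitem__(self, i):
        return i * 10


def make_wrapper(cls):
    """Builds an index-returning subclass at runtime (the on-the-fly pattern)."""

    def __getitem__(self, i):
        return (i, cls.__getitem__(self, i))

    # The new class is never assigned to a module-level name,
    # so pickle's lookup by module + name fails.
    return type("DynamicWithIndex", (cls,), {"__getitem__": __getitem__})


class StaticWithIndex(Base):
    """Same behaviour, defined at module level, so pickle can find it by name."""

    def __getitem__(self, i):
        return (i, super().__getitem__(i))


def can_pickle(obj):
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except (pickle.PicklingError, AttributeError):
        return False


dynamic = make_wrapper(Base)()
static = StaticWithIndex()
print(dynamic[3], static[3])                    # (3, 30) (3, 30)
print(can_pickle(dynamic), can_pickle(static))  # False True
```

Defining the wrapper at module level, as `IndexDataset` does, is therefore the safer choice when workers may be spawned.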