I mean that I want to get all the indices produced by torch.randperm(len(dataset)) inside the DataLoader, so I can pass them to another function. Please give me some ideas. Thanks.
If you need to handle the indices outside of the DataLoader, I would suggest sampling these indices manually, e.g. with torch.randperm, and using them in a Subset, or writing a custom sampler and passing it to the DataLoader.
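A minimal sketch of both options (the RandPermSampler name and the dummy TensorDataset are just placeholders for illustration):

    import torch
    from torch.utils.data import DataLoader, Subset, Sampler, TensorDataset

    # Dummy dataset, only so the snippet runs on its own
    dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))

    # Option 1: draw the permutation yourself and wrap it in a Subset
    indices = torch.randperm(len(dataset)).tolist()   # keep this list to pass elsewhere
    subset_loader = DataLoader(Subset(dataset, indices), batch_size=4, shuffle=False)

    # Option 2: a custom sampler that remembers the permutation it used
    class RandPermSampler(Sampler):
        def __init__(self, data_source):
            self.indices = torch.randperm(len(data_source)).tolist()

        def __iter__(self):
            return iter(self.indices)

        def __len__(self):
            return len(self.indices)

    sampler = RandPermSampler(dataset)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)
    # sampler.indices now holds the exact order the DataLoader will iterate in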
Hi,
I think it’s a good idea and I want to try it, but I don’t know how to write a dataset class. Could you give me some examples (e.g. CIFAR-10)? Thanks very much :)
Here is my implementation:
from torch.utils.data import DataLoader, Dataset


class IndexDataset(Dataset):
    """Wraps a dataset so that each sample also carries its original index."""

    def __init__(self, dataset, subset=None):
        self.dataset = dataset
        self.subset = subset  # optional sequence of indices to restrict the dataset to

    def __getitem__(self, index):
        if self.subset is None:
            data = self.dataset[index]
            real_index = index
        else:
            # map the local index back to the index in the wrapped dataset
            real_index = self.subset[index]
            data = self.dataset[real_index]

        if isinstance(data, dict):
            data["real_index"] = real_index
            return data
        elif isinstance(data, list):
            return [real_index] + data
        else:
            raise NotImplementedError(f"Data type {type(data)} not supported")

    def __len__(self):
        if self.subset is not None:
            return len(self.subset)
        else:
            return len(self.dataset)
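With CIFAR-10 it could be used roughly like this (an untested sketch; the AsList adapter is my own addition, only there because torchvision's CIFAR10 returns (img, target) tuples, which the class above does not handle):

    import torch
    import torchvision
    import torchvision.transforms as T
    from torch.utils.data import DataLoader, Dataset

    class AsList(Dataset):
        # Tiny adapter: CIFAR10 yields (img, target) tuples, IndexDataset expects dict or list
        def __init__(self, dataset):
            self.dataset = dataset

        def __getitem__(self, index):
            return list(self.dataset[index])

        def __len__(self):
            return len(self.dataset)

    cifar = torchvision.datasets.CIFAR10(root="./data", train=True,
                                         download=True, transform=T.ToTensor())
    indices = torch.randperm(len(cifar)).tolist()

    dataset = IndexDataset(AsList(cifar), subset=indices)
    loader = DataLoader(dataset, batch_size=8)

    real_index, images, targets = next(iter(loader))
    # real_index tells you which original CIFAR-10 samples ended up in this batch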
I love the elegance and versatility of this answer. Unfortunately, it cannot be used in a multiprocessing setup (depending on its particular configuration, I guess): I ended up with a PicklingError (Can't pickle <class …>: attribute lookup … failed), which, I suspect, is because the wrapper class is created on the fly.
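One workaround that may apply here (assuming the error comes from the wrapper class being defined inside a function while the DataLoader workers use the spawn start method) is to define the class once at module top level, so worker processes can pickle it by name:

    # my_datasets.py -- hypothetical module name; the point is that the class lives
    # at module top level, where pickle can find it by its qualified name
    from torch.utils.data import Dataset

    class IndexDataset(Dataset):
        ...  # same body as above

    # training script:
    # from my_datasets import IndexDataset
    # loader = DataLoader(IndexDataset(base_dataset), num_workers=4)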