I was looping through the data in a dataloader to create a new dataset. I need to know which data points are which, so I want to get their indices (e.g. for CIFAR10). How does one do that?
Related:
In the thread you posted is a valid solution:
Here is a small example using CIFAR10 from the other thread:
The thing I can’t generalize is how to do this for any DataLoader. The solutions seem to be specific to the corresponding datasets, or am I missing something?
The index is specific to a Dataset, and you can return it in the __getitem__ function. The DataLoader just calls the __getitem__ function from its Dataset and iterates over it using the specified batch size. I don’t think there is an easy way to modify a DataLoader to return the index. At least, I don’t have an idea, sorry.
def __getitem__(self, index):
    data, target = self.cifar10[index]
    # Return the index along with the sample so each data point can be identified
    return data, target, index
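To show how the __getitem__ idea above might carry over to other datasets, here is a minimal sketch of a wrapper around any map-style Dataset that returns (data, target) pairs. The class name WithIndex is hypothetical, and a small random TensorDataset stands in for CIFAR10 so that nothing has to be downloaded.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class WithIndex(Dataset):
    """Hypothetical wrapper: returns (data, target, index) for a wrapped dataset."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        # Assumes the wrapped dataset yields (data, target) pairs.
        data, target = self.dataset[index]
        return data, target, index

    def __len__(self):
        return len(self.dataset)


# Synthetic stand-in for CIFAR10: 10 samples of shape 3x32x32 with integer labels.
images = torch.randn(10, 3, 32, 32)
labels = torch.arange(10) % 3
base = TensorDataset(images, labels)

loader = DataLoader(WithIndex(base), batch_size=4, shuffle=True)

seen = []
for data, target, idx in loader:
    # The default collate function turns the per-sample int indices into a 1-D tensor.
    assert torch.equal(data, images[idx])
    seen.extend(idx.tolist())
```

Even with shuffle=True, each batch carries the original dataset indices, so every sample can be traced back to its position in the wrapped dataset.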
Hi @Brando_Miranda, did you solve your problem already? When I saw your last question, what came to mind was extracting the data from cifar10 as a NumPy array, like this:
data = cifar10.train_data  # older torchvision; newer releases expose this as cifar10.data
Check out the source code here. You can then work with this array, apply your rule to get the indices you want, and pass them to SubsetRandomSampler.
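A rough sketch of that suggestion, under some assumptions: the arrays below are random stand-ins for the CIFAR10 data and labels, and the selection rule (keep labels 3 and 7) is only a made-up example. The one real API relied on here is SubsetRandomSampler, which draws from the given list of indices in random order.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Synthetic stand-in for the CIFAR10 arrays (assumption: real arrays would come from the dataset object).
rng = np.random.default_rng(0)
data = rng.integers(0, 256, size=(20, 32, 32, 3), dtype=np.uint8)
targets = np.arange(20) % 10

# Hypothetical rule: keep only the samples whose label is 3 or 7.
selected = np.flatnonzero(np.isin(targets, [3, 7])).tolist()

dataset = TensorDataset(torch.from_numpy(data), torch.from_numpy(targets))

# The sampler restricts the DataLoader to the selected indices, visited in random order.
loader = DataLoader(dataset, batch_size=2, sampler=SubsetRandomSampler(selected))

drawn_labels = [int(t) for _, batch_targets in loader for t in batch_targets]
```

Note that the sampler controls which samples are drawn, but the batches themselves still do not contain the indices. To get those back as well, the dataset would have to return them from __getitem__.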
Thanks! I wish that would have occurred to me to look into…
See here for a general solution: How to retrieve the sample indices of a mini-batch
Here is my implementation.
from torch.utils.data import DataLoader, Dataset

class IndexDataset(Dataset):
    """Wraps a dataset so that every sample also carries its original index."""

    def __init__(self, dataset, subset=None):
        self.dataset = dataset
        # Optional sequence of indices into `dataset`; None means use all samples
        self.subset = subset

    def __getitem__(self, index):
        # Map the position in the subset back to the index in the wrapped dataset
        if self.subset is None:
            real_index = index
        else:
            real_index = self.subset[index]
        data = self.dataset[real_index]

        if isinstance(data, dict):
            data["real_index"] = real_index
            return data
        elif isinstance(data, (list, tuple)):
            # Many datasets (e.g. torchvision's) return tuples, so accept those too
            return [real_index, *data]
        else:
            raise NotImplementedError(f"Data type {type(data)} not supported")

    def __len__(self):
        if self.subset is not None:
            return len(self.subset)
        else:
            return len(self.dataset)