Hey,
I am using a custom Dataset class because I need each sample's index in my data. However, once I subset the data, I can't use it anymore. I need to remove samples by their index, and I suspect `__getitem__` might be the cause. I'm really stuck and would appreciate some hints.
import torch
from torch.utils.data import Subset
from torchvision import datasets
N_train = 300  # number of training samples kept by random_split in the script below


def make_dataset_with_index(
    dataset_class,
):
    """Return a subclass of *dataset_class* whose items also carry their index.

    ``DatasetWithIndex(...)[i]`` returns ``(image, label, i)`` instead of the
    base class's ``(image, label)``. Constructor arguments are forwarded
    unchanged to *dataset_class* (via the inherited ``__init__``).

    Args:
        dataset_class: A map-style dataset class whose ``__getitem__``
            returns a 2-tuple, e.g. ``torchvision.datasets.MNIST``.

    Returns:
        The ``DatasetWithIndex`` subclass (a class, not an instance).
    """

    class DatasetWithIndex(dataset_class):
        # __init__ and __len__ are inherited from dataset_class, so the
        # underlying data is constructed (and possibly downloaded/parsed)
        # exactly once -- no second private copy of the dataset is kept.
        #
        # NOTE(review): because this class is defined inside a function, its
        # instances cannot be pickled; DataLoader(num_workers>0) with the
        # "spawn" start method (Windows/macOS default) will fail on it.

        def __getitem__(self, index):
            image, label = super().__getitem__(index)
            # ``index`` is always an index into THIS (full) dataset. When the
            # dataset is wrapped in torch.utils.data.Subset, Subset translates
            # its local position via ``subset.indices`` before calling us, so
            # the returned index still refers to the original dataset.
            return (image, label, index)

    return DatasetWithIndex
train_dataset = make_dataset_with_index(
    datasets.MNIST,
)(
    "/tmp/mnist",
    train=True,
    download=True,
)

# subsampling to reduce data
train_dataset, _ = torch.utils.data.random_split(
    train_dataset,
    [N_train, len(train_dataset) - N_train],
    generator=torch.Generator().manual_seed(1),
)

indices_to_remove = [35845, 27522, 33254]  # These are indices within the original dataset
# Set for O(1) membership tests; the list above is kept as-is for other users.
_indices_to_remove_set = set(indices_to_remove)

# After random_split, ``train_dataset`` is a Subset of length N_train.
# Its ``.indices`` are positions in the FULL dataset, but the Subset itself
# only accepts 0..N_train-1. So the filtered indices must be applied to the
# full dataset (``train_dataset.dataset``), not to the Subset -- indexing the
# Subset with an original index (e.g. 35845) raises IndexError.
new_train_indices = [
    i for i in train_dataset.indices if i not in _indices_to_remove_set
]

# Create new Subsets (over the full, indexed dataset)
new_train_dataset = Subset(train_dataset.dataset, new_train_indices)

# Each batch is [images, labels, indices]; ``indices`` are original-dataset
# indices, so they can be compared directly against indices_to_remove.
loader = torch.utils.data.DataLoader(new_train_dataset, batch_size=32, shuffle=False)
for batch in loader:
    print(batch)