Index out of range after using Subset()

Hey,

I am using a custom Dataset class because I need the index in my data. But if I subset the data, I can’t use it anymore. I need to remove parts of the data by index, and I think the getitem might be the reason. But I am really stuck and would love some hints :slight_smile:

import torch
from torch.utils.data import Subset
from torchvision import datasets
N_train = 300


def make_dataset_with_index(
    dataset_class,
):
  
    class DatasetWithIndex(dataset_class):
        def __init__(self, *args, **kwargs):
            super(DatasetWithIndex, self).__init__(*args, **kwargs)

            self.dataset = dataset_class(*args, **kwargs)

        def __getitem__(self, index):
            image, label = super(DatasetWithIndex, self).__getitem__(index)
            return (image, label, index)

        def __len__(self):
            return len(self.dataset)
    return DatasetWithIndex



train_dataset = make_dataset_with_index(
        datasets.MNIST,
    )(
        "/tmp/mnist",
        train=True,
        download=True,
    )


    # subsampling to reduce data
train_dataset, _ = torch.utils.data.random_split(
        train_dataset,
        [N_train, len(train_dataset) - N_train],
        generator=torch.Generator().manual_seed(1),
    )

indices_to_remove = [35845,27522,33254]  # These are indices within the original dataset

new_train_indices = [i for i in train_dataset.indices if i not in indices_to_remove]

# Create new Subsets
new_train_dataset = Subset(train_dataset, new_train_indices)

loader = torch.utils.data.DataLoader(new_train_dataset, batch_size=32, shuffle=False)
for batch in loader:
    print(batch)