Removing samples after creation of dataloader

I’m trying to write a function that takes a DataLoader and a list of unwanted indexes, and removes all the unwanted samples from it (or returns a new DataLoader with those samples removed). The main problem is that the loader is created before the unwanted indexes are available.

The function would look something like this:

unwanted_indexes = [1, 8, 48, 947]

loader = remove_unwanted_samples(loader, unwanted_indexes)

Would it be a good idea to change the sampler? At least I know that DataLoader does not allow changing its attributes after it has been initialized. Should I subclass the loader, or what would be a worthwhile path to go down?

Alternatively, I also have access to the dataset, but modifying it seems like a worse idea, since its __init__ is expensive and the loader would need to be recreated afterwards.
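For reference, the dataset route I mean would look roughly like this (a sketch using torch.utils.data.Subset; dataset and batch_size stand in for whatever I already have, and the loader still has to be rebuilt):

from torch.utils.data import DataLoader, Subset

# Keep every index that is not in the unwanted list and wrap the
# existing dataset, so its expensive __init__ does not run again.
unwanted = set(unwanted_indexes)
keep = [i for i in range(len(dataset)) if i not in unwanted]

loader = DataLoader(Subset(dataset, keep), batch_size=batch_size, shuffle=True)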

Changing the sampler seems to me the most straightforward and easiest approach.

Modifying the RandomSampler implementation from here, you can do something like this:

import torch
import random
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler
from typing import Optional, Sized, Iterator

class MyRandomSampler(RandomSampler):
    r"""Sample elements randomly. 
    Not everything from RandomSampler is implemented.

    Args:
        data_source (Dataset): dataset to sample from
        forbidden  (Optional[list]): list of forbidden numbers
    """
    data_source: Sized
    forbidden: Optional[list]

    def __init__(self, data_source: Sized, forbidden: Optional[list] = None) -> None:
        super().__init__(data_source)
        self.data_source = data_source
        # Copy into a fresh list: a mutable default argument would be shared
        # across instances and is mutated by remove().
        self.forbidden = list(forbidden) if forbidden is not None else []
        self.refill()

    def remove(self, new_forbidden):
        # Remove numbers from the available indices
        for num in new_forbidden:
            if not (num in self.forbidden):
                self.forbidden.append(num)
        self._remove(new_forbidden)

    def _remove(self, to_remove):
        # Remove numbers just for this epoch
        for num in to_remove:
            if num in self.idx:
                self.idx.remove(num)

        self._num_samples = len(self.idx)

    def refill(self):
        # Refill the indices after iterating through the entire DataLoader
        self.idx = list(range(len(self.data_source)))
        self._remove(self.forbidden)

    def __iter__(self) -> Iterator[int]:
        # Sample without replacement in chunks of 32 (an arbitrary chunk size),
        # removing each chunk from self.idx so later chunks cannot repeat it.
        for _ in range(self.num_samples // 32):
            batch = random.sample(self.idx, 32)
            self._remove(batch)
            yield from batch
        # Yield the leftover indices, then reset for the next epoch.
        yield from random.sample(self.idx, self.num_samples % 32)
        self.refill()


# Fake Dataset to see which indices are being used
ds = torch.arange(50)

sampler = MyRandomSampler(ds, forbidden=[0, 2, 4, 6, 8, 9])

dl = DataLoader(ds, batch_size=10, sampler=sampler)

# See what indices are being used for the whole epoch
for data in dl:
    print(data)

# Remove more
print(dl.sampler.forbidden)
dl.sampler.remove([1, 3, 5])
print(dl.sampler.forbidden)

# See if it changed
for data in dl:
    print(data)
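
If you want the exact remove_unwanted_samples(loader, unwanted_indexes) signature from your question, a rough sketch on top of this sampler could look like the code below (it assumes a map-style dataset and only carries over the most common DataLoader arguments):

def remove_unwanted_samples(loader, unwanted_indexes):
    # If the loader already uses MyRandomSampler, just update it in place.
    if isinstance(loader.sampler, MyRandomSampler):
        loader.sampler.remove(list(unwanted_indexes))
        return loader
    # Otherwise, build a new loader over the same dataset with the custom sampler.
    sampler = MyRandomSampler(loader.dataset, forbidden=list(unwanted_indexes))
    return DataLoader(loader.dataset,
                      batch_size=loader.batch_size,
                      sampler=sampler,
                      num_workers=loader.num_workers,
                      collate_fn=loader.collate_fn,
                      pin_memory=loader.pin_memory)

# Usage, matching the question:
# loader = remove_unwanted_samples(loader, [1, 8, 48, 947])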

Hope this helps :smile: