Removing samples after creation of dataloader

I’m trying to create a function that takes in a dataloader and a list of unwanted indexes. The function would then remove all the unwanted samples from the dataloader (or return a new dataloader with the samples removed). The main problem is that the loader will be created before the unwanted_indexes are available.

The function would look something like this:

unwanted_indexes = [1, 8, 48, 947]

loader = remove_unwanted_samples(loader, unwanted_indexes)

Would it be a good idea to change the sampler? At least I know that DataLoader does not directly allow changing attributes after it has been initialized. Should I inherit from the old loader, or what would be a worthwhile path to go down?

Alternatively, I also have access to the dataset, but modifying it seems like a worse idea, since its init is expensive and the loader would need to be recreated afterwards.
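For reference, this is roughly the dataset-side route I mean (just a sketch; it reuses the already-initialized dataset via a Subset, but it still assumes that rebuilding the loader is acceptable and that its settings can simply be copied over):

from torch.utils.data import DataLoader, Subset

def remove_unwanted_samples(loader, unwanted_indexes):
    # reuse the dataset already inside the loader, so its expensive __init__ is not run again
    dataset = loader.dataset
    forbidden = set(unwanted_indexes)
    keep = [i for i in range(len(dataset)) if i not in forbidden]
    # wrap the dataset in a Subset view and rebuild the loader around it
    return DataLoader(Subset(dataset, keep),
                      batch_size=loader.batch_size,
                      shuffle=True,  # assumption: the original loader shuffled
                      num_workers=loader.num_workers)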

Changing the sampler seems to be the more straightforward and easier method to me.

Modifying the RandomSampler from here, you can do something like this.

import torch
import random
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler
from typing import Optional, Sized, Iterator

class MyRandomSampler(RandomSampler):
    r"""Sample elements randomly. 
    Not everything from RandomSampler is implemented.

    Args:
        data_source (Dataset): dataset to sample from
        forbidden  (Optional[list]): list of forbidden numbers
    """
    data_source: Sized
    forbidden: Optional[list]

    def __init__(self, data_source: Sized, forbidden: Optional[list] = None) -> None:
        super().__init__(data_source)
        self.data_source = data_source
        # avoid a shared mutable default argument
        self.forbidden = forbidden if forbidden is not None else []
        self.refill()

    def remove(self, new_forbidden):
        # Remove numbers from the available indices
        for num in new_forbidden:
            if not (num in self.forbidden):
                self.forbidden.append(num)
        self._remove(new_forbidden)

    def _remove(self, to_remove):
        # Remove numbers just for this epoch
        for num in to_remove:
            if num in self.idx:
                self.idx.remove(num)

        self._num_samples = len(self.idx)

    def refill(self):
        # Refill the indices after iterating through the entire DataLoader
        self.idx = list(range(len(self.data_source)))
        self._remove(self.forbidden)

    def __iter__(self) -> Iterator[int]:
        # 32 is just an internal shuffling chunk size; it is independent of the DataLoader's batch_size
        for _ in range(self.num_samples // 32):
            batch = random.sample(self.idx, 32)
            self._remove(batch)
            yield from batch
        yield from random.sample(self.idx, self.num_samples % 32)
        self.refill()


# Fake Dataset to see which indices are being used
ds = torch.arange(50)

sampler = MyRandomSampler(ds, forbidden=[0, 2, 4, 6, 8, 9])

dl = DataLoader(ds, batch_size=10, sampler=sampler)

# See what indices are being used for the whole epoch
for data in dl:
    print(data)

# Remove more
print(dl.sampler.forbidden)
dl.sampler.remove([1,3,5])
print(dl.sampler.forbidden)

# See if it changed
for data in dl:
    print(data)
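
As a rough usage sketch in a training run (assuming the ds and dl objects from the example above, and with made-up indices standing in for whatever criterion decides which samples to drop), the forbidden list can keep growing between epochs without rebuilding the DataLoader:

num_epochs = 3
for epoch in range(num_epochs):
    seen = []
    for batch in dl:  # iterates with the current forbidden list
        seen.extend(batch.tolist())
    print(f"epoch {epoch}: used {len(seen)} samples")
    # made-up indices; in practice these would come from your own criterion
    dl.sampler.remove([epoch * 10, epoch * 10 + 1])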

Hope this helps :smile:

Hi @Matias_Vasquez, I’m trying to implement a similar sampler with a loss_threshold value that is used to append samples to the forbidden list. Samples whose loss falls below a certain value are appended to the forbidden list and hence removed from training. After every epoch I want to check the losses again, send the higher-loss samples back to training, and append the lower-loss samples back to the forbidden list, so the contents of the forbidden list change after every epoch. The loss_threshold is the standard deviation of all the losses.
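
Just to make the selection rule concrete, here is a tiny sketch of that step with made-up losses (statistics.stdev gives the threshold, and everything below it would go into the forbidden list):

import random
import statistics

losses = [random.uniform(0, 1) for _ in range(10)]  # stand-in per-sample losses
loss_threshold = statistics.stdev(losses)           # threshold = std. dev. of all losses
low_loss_indices = [i for i, l in enumerate(losses) if l < loss_threshold]
print(low_loss_indices)  # these indices would be appended to the forbidden list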

import torch
import random
import statistics
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler
from typing import Optional, Sized, Iterator

class LossSampler(RandomSampler):
    r"""Sample elements randomly. 
    Not everything from RandomSampler is implemented.

    Args:
        data_source (Dataset): dataset to sample from
        forbidden (Optional[list]): list of forbidden numbers
        losses (Optional[list]): list of losses associated with each sample
        loss_threshold (float): samples whose loss is below this value are excluded from the next epoch
        mode (Optional[str]): 'train' or 'validate'
    """
    data_source: Sized
    forbidden: Optional[list]
    losses: Optional[list]
    loss_threshold: Optional[float]
    mode: Optional[str]

    def __init__(
        self,
        data_source: Sized,
        forbidden: Optional[list] = None,
        losses: Optional[list] = None,
        loss_threshold: Optional[float] = None,
        mode: Optional[str] = 'train'
    ) -> None:
        super().__init__(data_source)
        self.data_source = data_source
        # avoid shared mutable default arguments
        self.forbidden = forbidden if forbidden is not None else []
        self.losses = losses if losses is not None else [0] * len(data_source)
        self.loss_threshold = loss_threshold if loss_threshold is not None else 0.0
        self.mode = mode
        self.refill()

    def remove(self, losses, loss_thresh):
        # Forbid every sample whose loss is below the threshold
        to_rm = [i for i, loss in enumerate(losses) if loss < loss_thresh]
        self.forbidden.extend(i for i in to_rm if i not in self.forbidden)
        self._remove(to_rm)

    def _remove(self, to_remove):
        # Remove numbers just for this epoch
        for num in to_remove:
            if num in self.idx:
                self.idx.remove(num)

        self._num_samples = len(self.idx)

    def toggle_mode(self, mode):
        self.mode = mode
        if mode == "train":
            print("Training Mode Initiated")
        elif mode == "validate":
            print("Validation Mode Initiated")
        else:
            print("Invalid Mode Assigned")
            
    def refill(self):
        # Refill the indices after iterating through the entire DataLoader
        self.idx = list(range(len(self.data_source)))
        self._remove(self.forbidden)

    def __iter__(self) -> Iterator[int]:
        # 40 is just an internal shuffling chunk size, independent of the DataLoader's batch_size
        for _ in range(self.num_samples // 40):
            batch = random.sample(self.idx, 40)
            self._remove(batch)
            yield from batch
        yield from random.sample(self.idx, self.num_samples % 40)
        # at the end of the epoch, forbid every sample whose loss fell below the threshold
        self.remove(self.losses, self.loss_threshold)
        self.refill()

epochs = 4

# Fake Dataset to see which indices are being used
ds = torch.arange(50)
losses = [random.uniform(0, 1) for _ in range(len(ds))]
loss_thresh = statistics.stdev(losses)

# initial sampler; its forbidden list is updated after every epoch
sampler = LossSampler(ds, losses=losses, loss_threshold=loss_thresh)
f = sampler.forbidden

for epoch in range(epochs):
    print("Epoch:", epoch)

    # recalculate the per-sample losses and the threshold for this epoch
    losses = [random.uniform(0, 1) for _ in range(len(ds))]
    loss_thresh = statistics.stdev(losses)

    # recalculated losses for the forbidden samples
    f = list(set(f))
    forbidden_losses = [random.uniform(0, 1) for _ in range(len(f))]

    if len(f) != 0:  # ensuring the forbidden list is not empty
        # creating a dictionary of forbidden sample indices and their recalculated losses
        forbid_dict = dict(zip(f, forbidden_losses))
        print("set", f)
        # traversing the dictionary and appending samples above the threshold back to the dataset
        for floss in forbid_dict:
            if forbid_dict[floss] > loss_thresh:
                print(floss)
                losses.append(forbid_dict[floss])
                floss_tensor = torch.tensor([floss], dtype=ds.dtype)
                ds = torch.cat((ds, floss_tensor))

    # forbid the samples whose loss fell below the new threshold
    sampler.remove(losses, loss_thresh)
    f = sampler.forbidden
    print("forbidden", f)

# rebuild the sampler and loader around the (possibly grown) dataset, keeping the accumulated forbidden list
sampler = LossSampler(ds, forbidden=f, losses=losses, loss_threshold=loss_thresh)
dl = DataLoader(ds, batch_size=10, sampler=sampler)

# See what indices are being used for the whole epoch
for data in dl:
    print(data)


print("new")
# See if it changed
for data in dl:
    print(data)

Here’s the code I’ve managed so far, but it is not performing as expected. I’m a newbie to PyTorch and hence unable to understand the whole sampler code yet. Any leads will be appreciated. Thanks…