Hi @Matias_Vasquez , I’m trying to implement a similar sampler with a loss_threshold value that can be used to append values to forbidden list. Values below a certain loss are appended to forbidden list and hence are removed from the training. I want to check the list after every epoch for loss and send the higher loss values to training and append lower loss samples back to forbidden list. Hence the structure of forbidden list changes after every epoch. The loss_threshold is the standard deviation of all the losses.
import torch
import random
import statistics
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler,SequentialSampler
from typing import Optional, Sized, Iterator
import numpy as np
class LossSampler(RandomSampler):
r"""Sample elements randomly.
Not everything from RandomSampler is implemented.
Args:
data_source (Dataset): dataset to sample from
forbidden (Optional[list]): list of forbidden numbers
losses (Optional[list]): list of losses associated with each sample
loss_threshold (float): minimum loss to exclude samples from the next epoch
"""
data_source: Sized
forbidden: Optional[list]
losses: Optional[list]
loss_threshold: Optional[float]
mode: Optional[str]
def __init__(
self,
data_source: Sized,
forbidden: Optional[list] = [],
losses: Optional[list] = [],
loss_threshold: Optional[float] = [],
mode: Optional[str]='train'
) -> None:
super().__init__(data_source)
self.data_source = data_source
self.forbidden = forbidden
self.losses = losses if losses is not None else [0] * len(data_source)
self.loss_threshold = loss_threshold
self.refill()
def remove(self, loss, loss_thresh):
# Remove numbers from the available indices
to_rm=[i for i, loss in enumerate(loss) if loss < loss_thresh]
self.forbidden.extend(to_rm)
self._remove(to_rm)
def _remove(self, to_remove):
# Remove numbers just for this epoch
for num in to_remove:
if num in self.idx:
self.idx.remove(num)
self._num_samples = len(self.idx)
def toggle_mode(self,mode):
if mode=="train":
print("Training Mode Initiated")
elif mode=="validate":
print("Validation Mode Initiated")
else:
print("Invalid Mode Assigned")
def refill(self):
# Refill the indices after iterating through the entire DataLoader
self.idx = list(range(len(self.data_source)))
self._remove(self.forbidden)
def __iter__(self) -> Iterator[int]:
for _ in range(self.num_samples//40):
batch = random.sample(self.idx, 40)
self._remove(batch)
yield from batch
yield from random.sample(self.idx, self.num_samples % 40)
self.remove(self.losses,self.loss_threshold)
self.refill()
epochs=4
for epoch in range(epochs):
print("Epoch:", epoch)
# Fake Dataset to see which indices are being used
ds = torch.arange(50)
losses = [random.uniform(0, 1) for _ in range(len(ds))]
loss_thresh=statistics.stdev(losses)
f=set(f)
f=list(f)
#recalculated losses for forbidden
forbidden_losses = [random.uniform(0, 1) for _ in range(len(f))]
# print("loss",np.round(losses,2),"\n",ds,'\n')
# print("forbidden loss",len(forbidden_losses),np.round(forbidden_losses,2),"\n",len(f),f,'\n')
#setting threshold for further processing
if len(f)!=0: #ensuring the forbidden list is not empty
forbid_dict={}
f=set(f)
f=list(f)
print("set",f)
#creating dictionary of sample and losses
for key in f:
for value in forbidden_losses:
forbid_dict[key] = value
forbidden_losses.remove(value)
f.remove(key)
break
#traversing through dictionary and appending the samples above threshold to dataset
for floss in forbid_dict:
if forbid_dict[floss]>loss_thresh:
print(floss)
losses.append(forbid_dict[floss])
floss_tensor=torch.tensor([floss], dtype=torch.int8)
ds=torch.cat((ds,floss_tensor))
# Remove more
sampler.remove(losses, loss_thresh)
f=sampler.forbidden
print("forbidden",f)
#cal
sampler = LossSampler(ds, losses=losses, loss_threshold=loss_thresh)
dl = DataLoader(ds, batch_size=10, sampler=sampler)
# See what indices are being used for the whole epoch
for data in dl:
print(data)
print("new")
# See if it changed
for data in dl:
print(data)
Here’s the code I’ve managed till now, but it is not performing as expected, I’m a newbie to pytorch and hence unable to understand the whole sampler code yet. Any leads will be appreciated. Thanks…