I’m trying to subclass RandomSampler so that samples whose loss is below a given threshold are kept out of training. I don’t fully understand what the __iter__() method does, so I can’t work out how to modify it. In the source code, __iter__() looks like this:
```python
def __iter__(self) -> Iterator[int]:
    n = len(self.data_source)
    if self.generator is None:
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        generator = torch.Generator()
        generator.manual_seed(seed)
    else:
        generator = self.generator

    if self.replacement:
        for _ in range(self.num_samples // 32):
            yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
        yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
    else:
        for _ in range(self.num_samples // n):
            yield from torch.randperm(n, generator=generator).tolist()
        yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
```
I’d like to understand how this iterator actually chooses samples on each pass. Specifically: what is the purpose of the 32 when replacement is True? What do the yield from statements do here? How is num_samples determined? And what happens when num_samples equals n, and how does RandomSampler handle that case? That is my situation: n and num_samples appear to be the same, and the sampler seems to just step through the samples one at a time.
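To make the question concrete, this is how I currently read the two branches, written as a pure-Python mimic rather than the real implementation. The function name `mimic_random_sampler` and the `seed` and `chunk` parameters are my own; `random.Random` stands in for `torch.randint` and `torch.randperm`. My assumption is that 32 is only a chunk size for drawing random indices in small batches, not something that changes which samples can be picked.

```python
import random
from typing import Iterator, Optional


def mimic_random_sampler(
    n: int,
    num_samples: Optional[int] = None,
    replacement: bool = False,
    seed: int = 0,
    chunk: int = 32,
) -> Iterator[int]:
    """Pure-Python stand-in for RandomSampler.__iter__ (illustrative only, assumes n > 0)."""
    rng = random.Random(seed)
    if num_samples is None:
        # Assumption: when num_samples is not given, it falls back to the dataset size.
        num_samples = n

    if replacement:
        # Draw indices in chunks; the same index may appear more than once.
        for _ in range(num_samples // chunk):
            yield from [rng.randrange(n) for _ in range(chunk)]
        yield from [rng.randrange(n) for _ in range(num_samples % chunk)]
    else:
        # Emit as many full shuffles of 0..n-1 as fit into num_samples ...
        for _ in range(num_samples // n):
            perm = list(range(n))
            rng.shuffle(perm)
            yield from perm
        # ... then the first (num_samples % n) entries of one more shuffle.
        perm = list(range(n))
        rng.shuffle(perm)
        yield from perm[: num_samples % n]


indices = list(mimic_random_sampler(n=5))  # num_samples == n
print(sorted(indices))  # [0, 1, 2, 3, 4]
```

If this reading is right, then with num_samples equal to n and no replacement, the first loop runs exactly once, the final slice is empty, and the iterator yields every index exactly once in shuffled order. The indices come out one at a time only because `yield from` hands them to the caller individually; grouping them into batches would happen elsewhere.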
Additionally, I want to add a validation mode: after every epoch, the loss is re-checked for the samples on the forbidden list (those kept out of training), and any sample whose loss is now above the threshold is sent back into training.
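A rough sketch of the behaviour I have in mind is below. Everything in it is hypothetical: the class name `LossThresholdSampler`, its `update` method and the `forbidden` set are my own names, and it is a plain Python iterable rather than a real RandomSampler subclass. It also assumes the per-sample losses are computed somewhere else (for example with an unreduced loss) and passed in as a dict from index to loss.

```python
import random
from typing import Dict, Iterator, List, Set


class LossThresholdSampler:
    """Hypothetical sampler that skips indices whose last known loss is below a threshold."""

    def __init__(self, num_items: int, threshold: float, seed: int = 0) -> None:
        self.num_items = num_items
        self.threshold = threshold
        self.forbidden: Set[int] = set()  # indices currently kept out of training
        self._rng = random.Random(seed)

    def update(self, losses: Dict[int, float]) -> None:
        """Validation step: re-classify every index for which a fresh loss is given."""
        for idx, loss in losses.items():
            if loss < self.threshold:
                self.forbidden.add(idx)      # low loss: keep out of training
            else:
                self.forbidden.discard(idx)  # loss went back up: allow training again

    def _active(self) -> List[int]:
        return [i for i in range(self.num_items) if i not in self.forbidden]

    def __iter__(self) -> Iterator[int]:
        # Rebuilt on every pass, so changes made by update() apply to the next epoch.
        active = self._active()
        self._rng.shuffle(active)
        yield from active

    def __len__(self) -> int:
        return len(self._active())


sampler = LossThresholdSampler(num_items=4, threshold=0.1)
sampler.update({0: 0.01, 1: 0.9, 2: 0.02})  # 0 and 2 fall below the threshold
print(sorted(sampler))                       # [1, 3]
sampler.update({0: 0.8})                     # loss for 0 rose again after re-checking
print(sorted(sampler))                       # [0, 1, 3]
```

The idea would be to call `update` once per epoch with the losses of the forbidden samples (and optionally the trained ones), then let the next pass over the sampler pick up the new active set. I am assuming an object like this could be passed as the `sampler` argument of a DataLoader, but I have not verified that.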