Random Sampler iterator function

I’m trying to customise a RandomSampler subclass so that samples whose loss falls below a given threshold are kept out of training. I’m confused about what the __iter__() function does, so I can’t work out how to modify it. The __iter__() function in the source code looks like this:

def __iter__(self) -> Iterator[int]:
    n = len(self.data_source)
    if self.generator is None:
        # No generator supplied: create one and seed it randomly
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        generator = torch.Generator()
        generator.manual_seed(seed)
    else:
        generator = self.generator

    if self.replacement:
        # With replacement: draw indices in chunks of 32, then the remainder
        for _ in range(self.num_samples // 32):
            yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
        yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
    else:
        # Without replacement: full permutations, then part of one more
        for _ in range(self.num_samples // n):
            yield from torch.randperm(n, generator=generator).tolist()
        yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

I was wondering how the iterator function actually chooses the samples each time. Specifically: what is the purpose of the 32 when replacement is true, how does yield work here, and how is num_samples calculated? Also, what happens when num_samples and n have the same value, and how does RandomSampler handle that case? In my case n and num_samples seem to be the same, which makes it loop over the samples one at a time.

Additionally, I want to add a validation mode in which the samples on the forbidden list (those kept out of training) have their loss checked after every epoch, and the higher-loss samples are sent back to training.
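One way the forbidden-list idea could look, as a rough sketch only. `ThresholdRandomSampler`, `update_losses`, `allowed` and `forbidden` are hypothetical names, not part of PyTorch, and `random.shuffle` stands in for `torch.randperm` to keep the example dependency-free. It assumes the training or validation loop reports a per-sample loss keyed by dataset index.

```python
import random
from typing import Dict, Iterator, List


class ThresholdRandomSampler:
    """Hypothetical sampler: shuffles only indices whose last recorded loss is >= threshold."""

    def __init__(self, num_items: int, threshold: float, seed: int = 0) -> None:
        self.num_items = num_items
        self.threshold = threshold
        self.losses: Dict[int, float] = {}  # index -> most recent loss
        self.rng = random.Random(seed)

    def _loss(self, i: int) -> float:
        # Indices with no recorded loss count as "high loss", so epoch 1 sees everything.
        return self.losses.get(i, float("inf"))

    def allowed(self) -> List[int]:
        return [i for i in range(self.num_items) if self._loss(i) >= self.threshold]

    def forbidden(self) -> List[int]:
        return [i for i in range(self.num_items) if self._loss(i) < self.threshold]

    def update_losses(self, new_losses: Dict[int, float]) -> None:
        # Called after training or after a validation pass over the forbidden list.
        self.losses.update(new_losses)

    def __iter__(self) -> Iterator[int]:
        indices = self.allowed()
        self.rng.shuffle(indices)  # plays the role of torch.randperm
        yield from indices

    def __len__(self) -> int:
        return len(self.allowed())


sampler = ThresholdRandomSampler(num_items=5, threshold=0.5)
epoch1 = list(sampler)                           # all 5 indices, shuffled
sampler.update_losses({0: 0.1, 1: 0.9, 2: 0.2})
epoch2 = list(sampler)                           # 0 and 2 are now held out
sampler.update_losses({0: 0.8})                  # validation: loss of sample 0 went back up
epoch3 = list(sampler)                           # 0 returns to training, 2 stays out
```

Because the allowed list is rebuilt on every `__iter__` call, each epoch picks up the latest losses. In a real setup it would probably be more conventional to subclass `torch.utils.data.Sampler` and pass the instance as the `sampler=` argument of `DataLoader`; that wiring is not shown here.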

  1. The constant 32 keeps the sampler from generating too large a list at once: it draws only 32 indices at a time and draws more as needed, which reduces memory usage
  2. Using yield makes __iter__ a generator function: calling it creates a generator, which then returns one index at a time
  3. num_samples is a property defined in the same file:
@property
def num_samples(self) -> int:
    # dataset size might change at runtime
    if self._num_samples is None:
        return len(self.data_source)
    return self._num_samples
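
To make points 1 and 2 concrete, here is a minimal sketch of the same chunked `yield from` pattern in plain Python. It uses `random.Random.randrange` as a stand-in for `torch.randint`; the function name `chunked_indices` and the `chunk` and `seed` parameters are made up for illustration and are not part of PyTorch.

```python
import random
from typing import Iterator


def chunked_indices(n: int, num_samples: int, chunk: int = 32, seed: int = 0) -> Iterator[int]:
    """Yield num_samples random indices in [0, n), drawing at most `chunk` at a time."""
    rng = random.Random(seed)
    # Full chunks: each temporary list holds only `chunk` numbers, so memory stays small.
    for _ in range(num_samples // chunk):
        yield from [rng.randrange(n) for _ in range(chunk)]
    # Remainder: the leftover num_samples % chunk indices (possibly none).
    yield from [rng.randrange(n) for _ in range(num_samples % chunk)]


it = chunked_indices(n=10, num_samples=70)
first = next(it)   # the function body only starts running on this first next() call
rest = list(it)    # consumes the remaining 69 indices, one index per step
```

With `num_samples=70` and `chunk=32`, the loop runs twice (64 indices) and the final line yields the remaining 6, which is the same arithmetic as `num_samples // 32` and `num_samples % 32` in the replacement branch.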

Hi, thanks for responding. I have another question: how does num_samples differ from n in the code I attached to the question?

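n is len(self.data_source), i.e. the size of the dataset, while num_samples is the number of indices to draw and defaults to n when it isn't set. When the two are equal and replacement is False, the loop runs once and the final slice is empty, so it simply yields n indices from torch.randperm(n, generator=generator).tolist(): every index exactly once, in random order.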
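
As an illustration of the no-replacement branch, the sketch below mirrors its arithmetic in plain Python, with `random.shuffle` standing in for `torch.randperm`. The name `permuted_indices` and the `seed` parameter are hypothetical and not part of PyTorch.

```python
import random
from typing import Iterator, List


def permuted_indices(n: int, num_samples: int, seed: int = 0) -> Iterator[int]:
    """Mirror of the replacement=False branch, using random.shuffle instead of torch.randperm."""
    rng = random.Random(seed)

    def randperm() -> List[int]:
        perm = list(range(n))
        rng.shuffle(perm)
        return perm

    # One full permutation per complete pass over the dataset.
    for _ in range(num_samples // n):
        yield from randperm()
    # Partial pass: the first num_samples % n entries of one more permutation.
    yield from randperm()[: num_samples % n]


same = list(permuted_indices(n=5, num_samples=5))    # one pass: each index exactly once
more = list(permuted_indices(n=5, num_samples=12))   # two full passes plus 2 extra indices
```

So when num_samples equals n the result is a single shuffled pass over the dataset, and a larger num_samples just appends further shuffled passes.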