Cache batches of weighted samples for faster yielding

Hello,

I have a dataset with 82 million samples. Each sample can be categorized as being either inside or outside a mask. The number of inside pixels is much lower than the number of outside pixels. However, I would like to have a balanced distribution of in/out samples in every batch.

At the moment I am implementing a custom WeightedSampler where my __iter__() method looks like the following. I am making it compatible with DDP as well:

    def __iter__(self):
        # Draw enough indices for all replicas; each rank keeps its own slice below.
        sample_size = self.sample_size * self.num_replicas
        num_in_mask = int(sample_size * self.weight)
        num_out_mask = sample_size - num_in_mask

        # Counter used to vary the seed between calls; advanced on rank 0 only.
        g = torch.Generator()
        if self.rank == 0:
            self.counter += 1
        # torch.iinfo expects a torch dtype; the modulo applies to the whole sum.
        g.manual_seed((self.seed + self.epoch + self.counter) % torch.iinfo(torch.int64).max)

        # Sample positions with replacement, using the generator seeded above.
        in_mask_sample = torch.randint(len(self.dataset.in_mask), size=(num_in_mask,), generator=g)
        out_mask_sample = torch.randint(len(self.dataset.out_mask), size=(num_out_mask,), generator=g)
        # Map the sampled positions to dataset indices.
        in_mask_indices = self.dataset.in_mask[in_mask_sample]
        out_mask_indices = self.dataset.out_mask[out_mask_sample]
        # Mix inside and outside indices together.
        indices = torch.cat((in_mask_indices, out_mask_indices))
        permuted = torch.randperm(len(indices), generator=g)
        indices = indices[permuted]

        # Each rank takes a strided slice of the shared index list.
        indices = indices[self.rank:len(indices):self.num_replicas]

        return iter(indices.tolist())

self.weight is a probability. For example, 0.5 means half of the samples in my batch come from inside the mask and half from outside. Each call makes two randint calls and a randperm, which makes yielding a batch slow. I am working with batch sizes of 128/256 per GPU. I think I could create a flattened array of length 82 million in which every slice of 128/256 samples follows the desired distribution, e.g. half of the samples from inside and half from outside. But every attempt I have made at this has been prohibitively slow. Any ideas on how to speed up the dataloader?
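
For illustration, the flattened-array idea described above could be built in one vectorized pass instead of one call per batch. This is only a minimal sketch under assumptions, not a tested solution for the 82 million sample case: the function name balanced_epoch_indices and the num_batches parameter are hypothetical, in_mask and out_mask are assumed to be 1-D int64 tensors of dataset indices, and sampling is with replacement as in the __iter__() above.

```python
import torch


def balanced_epoch_indices(in_mask, out_mask, batch_size, num_batches,
                           weight=0.5, seed=0):
    """Return a flat index tensor in which every consecutive slice of
    `batch_size` entries holds int(batch_size * weight) in-mask indices.

    Hypothetical helper; `in_mask` / `out_mask` are assumed to be 1-D
    int64 tensors of dataset indices.
    """
    g = torch.Generator()
    g.manual_seed(seed)

    n_in = int(batch_size * weight)
    n_out = batch_size - n_in

    # One row per batch, sampled with replacement in a single call each.
    in_pos = torch.randint(len(in_mask), (num_batches, n_in), generator=g)
    out_pos = torch.randint(len(out_mask), (num_batches, n_out), generator=g)
    batches = torch.cat((in_mask[in_pos], out_mask[out_pos]), dim=1)

    # Shuffle within each batch: argsort of random keys yields an
    # independent permutation per row without a Python loop.
    order = torch.rand(num_batches, batch_size, generator=g).argsort(dim=1)
    batches = torch.gather(batches, 1, order)

    # Flatten so that flat[i * batch_size:(i + 1) * batch_size] is batch i.
    return batches.reshape(-1)
```

The per-row shuffle can be dropped if the order of samples within a batch does not matter. The flat tensor can be reshaped back with view(num_batches, batch_size), and each row could then be handed to a DataLoader through its batch_sampler argument, which accepts any iterable that yields lists of indices.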