I have a dataset with 82 million samples. Each sample can be categorized as being either inside or outside a mask. The number of in-mask samples is much lower than the number of out-of-mask samples. However, I would like a balanced distribution of in/out samples in every batch.
At the moment I am implementing a custom WeightedSampler whose __iter__() method looks like the following (I am also making it compatible with DDP):
    def __iter__(self):
        sample_size = self.sample_size * self.num_replicas
        num_in_mask = int(sample_size * self.weight)
        num_out_mask = sample_size - num_in_mask

        # advance a per-epoch counter (note: only rank 0 increments it here)
        # and derive a deterministic seed from it
        g = torch.Generator()
        if self.rank == 0:
            self.counter += 1
        g.manual_seed((self.seed + self.epoch + self.counter) % torch.iinfo(torch.int64).max)

        # sample positions into the in/out index lists, then map them
        # to actual dataset indices
        in_mask_sample = torch.randint(len(self.dataset.in_mask), size=(num_in_mask,), generator=g)
        out_mask_sample = torch.randint(len(self.dataset.out_mask), size=(num_out_mask,), generator=g)
        in_mask_indices = self.dataset.in_mask[in_mask_sample]
        out_mask_indices = self.dataset.out_mask[out_mask_sample]

        # interleave in/out indices with a global shuffle, then take
        # this rank's round-robin shard
        indices = torch.cat((in_mask_indices, out_mask_indices))
        permuted = torch.randperm(len(indices), generator=g)
        indices = indices[permuted]
        indices = indices[self.rank : len(indices) : self.num_replicas]
        return iter(indices.tolist())
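As a side note on the DDP part: the `indices[self.rank : len(indices) : self.num_replicas]` slice at the end shards the shuffled list round-robin across replicas, so every rank sees a disjoint subset. A tiny standalone illustration (plain Python, values invented for the example):

```python
# pretend these are the globally shuffled dataset indices
indices = list(range(10))
num_replicas = 2

# each rank keeps every num_replicas-th element, starting at its rank
shards = [indices[rank::num_replicas] for rank in range(num_replicas)]

# rank 0 gets positions 0, 2, 4, ...; rank 1 gets 1, 3, 5, ...
# together the shards cover every index exactly once
assert sorted(shards[0] + shards[1]) == indices
```

This only partitions correctly if every rank computes the same shuffled `indices`, which is why the seed derivation above has to agree across replicas.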
self.weight is a probability. For example,
0.5 means my batch has half its samples from inside the mask and half from outside. The two
randint calls and the
randperm make yielding a batch slow. I am working with batch sizes of 128/256 per GPU. I think I could create a flattened array of length 82 million where every slice of 128/256 samples follows a distribution with half of the samples from in and half from out, but every attempt at this has been prohibitively slow. Any idea how to speed up the dataloader?
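One way to build such a flattened array without a per-batch Python loop is to sample the whole epoch in a few vectorized calls and shuffle within batch-sized rows. A minimal sketch with NumPy (the function name `balanced_flat_indices` and its signature are my invention, not from the question; the same reshape/argsort trick works with torch tensors):

```python
import numpy as np

def balanced_flat_indices(in_idx, out_idx, batch_size, num_batches, weight=0.5, seed=0):
    """Build a flat index array where every contiguous batch_size slice
    mixes in/out samples according to weight (hypothetical sketch)."""
    rng = np.random.default_rng(seed)
    k_in = int(batch_size * weight)   # in-mask samples per batch
    k_out = batch_size - k_in         # out-mask samples per batch

    # sample with replacement for the whole epoch in one call per class
    ins = rng.integers(len(in_idx), size=(num_batches, k_in))
    outs = rng.integers(len(out_idx), size=(num_batches, k_out))

    # map sampled positions to dataset indices; one row per batch
    rows = np.concatenate([np.asarray(in_idx)[ins], np.asarray(out_idx)[outs]], axis=1)

    # independent shuffle within each row: argsort of uniform noise
    perm = rng.random(rows.shape).argsort(axis=1)
    rows = np.take_along_axis(rows, perm, axis=1)
    return rows.reshape(-1)
```

Every contiguous `batch_size` slice of the returned array then contains exactly `int(batch_size * weight)` in-mask indices, so a sequential DataLoader over it yields balanced batches without any per-batch shuffling at iteration time.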