Hello,
I have a dataset with 82 million samples. Each sample is categorized as being either inside or outside a mask, and inside pixels are far less common than outside pixels. However, I would like every batch to contain a balanced mix of inside and outside samples.
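The sampler indexes into `self.dataset.in_mask` and `self.dataset.out_mask`. Assuming these are 1-D tensors of sample indices, they can be precomputed once from a flat boolean mask, for example like this (the toy `mask` is hypothetical; the real one would have ~82 million entries):

```python
import torch

# Hypothetical stand-in for the real mask: one boolean per sample
# (True = inside the mask).
mask = torch.tensor([False, True, False, False, True, False, False, False])

# Indices of samples inside and outside the mask, as int64 tensors.
in_mask = torch.nonzero(mask, as_tuple=True)[0]
out_mask = torch.nonzero(~mask, as_tuple=True)[0]

print(in_mask.tolist())   # [1, 4]
print(out_mask.tolist())  # [0, 2, 3, 5, 6, 7]
```

Computing these once (e.g. in the dataset's `__init__`) keeps the per-epoch work in the sampler down to random indexing.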
At the moment I am implementing a custom WeightedSampler whose `__iter__()` method looks like the following (I am also making it compatible with DDP):
```python
import torch

def __iter__(self):
    # Draw indices for all replicas at once; each rank keeps a strided slice.
    sample_size = self.sample_size * self.num_replicas
    num_in_mask = int(sample_size * self.weight)
    num_out_mask = sample_size - num_in_mask

    # Every rank must seed its generator identically so that all ranks see
    # the same permutation; increment the counter on every rank, not only rank 0.
    self.counter += 1
    g = torch.Generator()
    g.manual_seed((self.seed + self.epoch + self.counter) % torch.iinfo(torch.int64).max)

    # Sample with replacement from each class, using the seeded generator `g`.
    in_mask_sample = torch.randint(len(self.dataset.in_mask), size=(num_in_mask,), generator=g)
    out_mask_sample = torch.randint(len(self.dataset.out_mask), size=(num_out_mask,), generator=g)
    in_mask_indices = self.dataset.in_mask[in_mask_sample]
    out_mask_indices = self.dataset.out_mask[out_mask_sample]

    # Shuffle the concatenation, then give each rank every num_replicas-th index.
    indices = torch.cat((in_mask_indices, out_mask_indices))
    permuted = torch.randperm(len(indices), generator=g)
    indices = indices[permuted]
    indices = indices[self.rank::self.num_replicas]
    return iter(indices.tolist())
```
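Because the shuffle in `__iter__` is global, the in/out ratio is exact over the whole epoch but only approximate inside any single batch of 128/256. A quick self-contained check on toy labels (the sizes and the seed are arbitrary assumptions):

```python
import torch

g = torch.Generator()
g.manual_seed(0)

# Toy epoch (arbitrary sizes): 100 batches of 256, exactly half labelled 1 ("in mask").
batch_size, num_batches = 256, 100
half = batch_size * num_batches // 2
labels = torch.cat((torch.ones(half), torch.zeros(half)))

# One global shuffle, as in __iter__ above.
labels = labels[torch.randperm(len(labels), generator=g)]

# Fraction of in-mask samples in each consecutive batch.
per_batch = labels.view(num_batches, batch_size).mean(dim=1)
print(per_batch.mean().item())                         # 0.5 over the whole epoch
print(per_batch.min().item(), per_batch.max().item())  # individual batches can deviate from 0.5
```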
`self.weight` is a probability: for example, `0.5` means half of the samples in a batch come from inside the mask and half from outside. The two `randint` calls and the `randperm` make yielding a batch slow. I am working with batch sizes of 128/256 per GPU. I think I could create a flattened array of length 82 million in which every consecutive slice of 128/256 samples follows the desired distribution (e.g. half from inside, half from outside), but every attempt I have made at this has been prohibitively slow. Any idea how to speed up the dataloader?