How to collect custom batches

Hi!
I am currently working on a system that requires me to load batches made up of random pixels drawn from a given set of images. More precisely, if there are N images in the dataset and the batch size is B, then for each epoch I need to collect N batches of size B, each one composed of B random pixels of the corresponding image. For example, the first batch is composed of B randomly selected pixels of the first image, the second batch of B pixels of the second image, and so on.
What would be the recommended way to do this?

I am not very familiar with this domain, but it seems that for each image you could take B random crops of the image to form a batch. You could do that inside the __getitem__ or __iter__ method of your custom Dataset (or DataPipe).
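
A minimal sketch of the __getitem__ approach suggested above, assuming the images are already in memory as (C, H, W) tensors and that B is no larger than H * W. The class name RandomPixelDataset and the parameter pixels_per_image are hypothetical names chosen for illustration, not part of PyTorch. Each item the dataset returns is already a full batch of B pixels from one image, so the DataLoader is created with batch_size=None to turn off automatic batching.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomPixelDataset(Dataset):
    """Hypothetical dataset: item i is a batch of B random pixels of image i."""

    def __init__(self, images, pixels_per_image):
        # images: sequence of tensors shaped (C, H, W); H and W may differ per image
        self.images = images
        self.pixels_per_image = pixels_per_image

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        _, height, width = image.shape
        # Draw B distinct flat pixel positions (assumes B <= H * W)
        flat = torch.randperm(height * width)[: self.pixels_per_image]
        rows = torch.div(flat, width, rounding_mode="floor")
        cols = flat % width
        pixels = image[:, rows, cols].T  # shape (B, C)
        coords = torch.stack([rows, cols], dim=1)  # shape (B, 2)
        return pixels, coords


# Dummy data standing in for a real image set
images = [torch.rand(3, 32, 48) for _ in range(4)]
dataset = RandomPixelDataset(images, pixels_per_image=16)

# batch_size=None disables automatic batching: one dataset item = one batch
loader = DataLoader(dataset, batch_size=None, shuffle=False)
batches = list(loader)
```

With shuffle=False the batches come out in image order, and a fresh set of pixels is drawn every time an image is visited, so each epoch sees different pixels. Sampling with replacement (for example with torch.randint) would be an alternative if B can exceed the number of pixels in an image.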

I ended up with this solution, which uses a DistributedSampler when training on multiple GPUs:

import torch
from torch.utils.data import DataLoader

def train_dataloader(self):
    """Return a DataLoader for training, configured from hparams.

    Returns:
        DataLoader: DataLoader ready to deliver samples for training
    """
    # Use a distributed sampler only when training on multiple GPUs;
    # with sampler=None the DataLoader iterates in sequential order.
    sampler = None
    if self.hparams.num_gpus > 1:
        sampler = torch.utils.data.distributed.DistributedSampler(
            self.train_dataset, shuffle=False)
    return DataLoader(
        self.train_dataset,
        shuffle=False,
        num_workers=self.hparams.num_workers,
        batch_size=self.hparams.batch_size,
        pin_memory=False,
        sampler=sampler)