Custom Sampler to sample "subbatches" within class in extra dimension

I have some questions regarding custom samplers.

First some background:
I have some large images, each of which is divided into many smaller tiles. The number of tiles varies between images. I have already built a custom dataset from torchvision.datasets.ImageFolder, which records the class and the parent image (called the parent slide in my use case) of each tile on initialization.
Now what I would like to do is have an extra “subbatch” dimension in the data I get from the DataLoader, where all samples within that dimension are either:

  • from the same class, or
  • from the same slide

So a batch should have the shape [batch_size, subbatch_size, channels, x_resolution, y_resolution].

Thanks to the answer of the almighty @ptrblck in this thread,
I was already able to write a custom sampler that should sample the correct indices within such “subbatches”. I solved it a little differently from what's done in the thread, but it pointed me in the right direction. Here’s the code:

import torch
from torch.utils.data import Sampler


class TileSubBatchSampler(Sampler):
    def __init__(self, subbatch_size, tile_image_dataset, mode='class', shuffle=True):
        self.subbatch_size = subbatch_size
        self.ds = tile_image_dataset
        self.mode = mode
        self.shuffle = shuffle
        # One index tensor per group (slide or class).
        if mode == 'slide':
            self.subs_idx = [torch.tensor(idx) for idx in self.ds.parent_slide.values()]
        elif mode == 'class':
            self.subs_idx = [torch.tensor(idx) for idx in self.ds.idx_in_class.values()]
        else:
            raise ValueError(f"mode must be 'slide' or 'class', got {mode!r}")

    def get_subbatched_sample_indexes(self):
        # Shuffle the indices within each group if requested.
        idx_per_sub = ([sub_idx[torch.randperm(len(sub_idx))] for sub_idx in self.subs_idx]
                       if self.shuffle else self.subs_idx)
        # Cut each group into full chunks of subbatch_size (the remainder is dropped)
        # and concatenate everything into one [num_subbatches, subbatch_size] tensor.
        # ">=" keeps groups with exactly subbatch_size tiles, consistent with __len__.
        subbatches_idx = torch.cat([
            torch.cat([sub_idx[i:i + self.subbatch_size][None, :]
                       for i in range(0, len(sub_idx), self.subbatch_size)
                       if len(sub_idx[i:i + self.subbatch_size]) == self.subbatch_size], dim=0)
            for sub_idx in idx_per_sub if len(sub_idx) >= self.subbatch_size], dim=0)
        if self.shuffle:
            # Shuffle the order of the subbatches themselves.
            subbatches_idx = subbatches_idx[torch.randperm(subbatches_idx.shape[0]), :]
        return subbatches_idx

    def __iter__(self):
        subbatches_idx = self.get_subbatched_sample_indexes()
        # Yield one 1-D index tensor of length subbatch_size per iteration.
        subbatches_idx = [subbatches_idx[n, :] for n in range(subbatches_idx.shape[0])]
        return iter(subbatches_idx)

    def __len__(self):
        return sum(len(sub_idx) // self.subbatch_size for sub_idx in self.subs_idx)

So on each iteration, it returns a torch.Tensor with subbatch_size indices, all of which are from the same class or from the same parent slide. Just to clarify: ds.parent_slide is a dict “slide_number: [list of indices in slide]” and ds.idx_in_class is a dict “class: [list of indices in class]”.

Now here's my actual question:

What else do I need to do to make this work? Does this just work out of the box when I pass the sampler to the DataLoader? I couldn’t try it yet, but I guess not? Do I have to build a custom BatchSampler as well, where I flatten the batch and subbatch dimensions into one larger batch dimension and reshape my batches back into the wanted shape further down the line?

Or is that what the collate_fn, which is passed to the DataLoader, can be used for?

Also, I would like to run this with DistributedDataParallel later, so I guess I have to turn it into a DistributedSampler as well? In that case, I can probably look here. Since my classes are imbalanced anyway, I could also take over the weighting part, which would obviously have to be adapted to sample from the subbatches rather than from the individual samples.
But if I have to build my own BatchSampler, which sampler needs to be the DistributedSampler?

It’s rather hard to derive this from the DataLoader code, since that code is quite involved, so any tips would be much appreciated!

I think it depends on the actual shape of each batch.
Given that each batch has the shape [batch_size, subbatch_size, channels, x_resolution, y_resolution], are the sizes in each dimension equal for each batch or would e.g. subbatch_size be different?
In the latter case, you could either flatten the subbatch patches into the batch dimension and use the default collate_fn or, if that’s not desired, you could write a custom collate_fn.
However, this also depends on the expected model input. I.e. would it work, if you put all patches into the batch dimension or do you want to process these patches somehow differently?
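As a rough illustration of the custom collate_fn option, here is a minimal sketch. It assumes each dataset item is already a (tiles, labels) pair with tiles of shape [subbatch_size, channels, height, width]; the function name flatten_subbatch_collate is made up for this example.

```python
import torch


def flatten_subbatch_collate(samples):
    """Merge a list of (tiles, labels) subbatches into one flat batch.

    Assumption: each tiles tensor is [subbatch_size, C, H, W] and each
    labels tensor is [subbatch_size]; subbatch_size may differ per item.
    """
    tiles = torch.cat([t for t, _ in samples], dim=0)
    labels = torch.cat([l for _, l in samples], dim=0)
    return tiles, labels


# Two subbatches of different size, which the default collate_fn could not stack.
samples = [
    (torch.zeros(3, 3, 8, 8), torch.zeros(3, dtype=torch.long)),
    (torch.ones(5, 3, 8, 8), torch.ones(5, dtype=torch.long)),
]
tiles, labels = flatten_subbatch_collate(samples)
```

Such a function would be passed as DataLoader(..., collate_fn=flatten_subbatch_collate). If all subbatches have the same size, the default collate_fn should already be able to stack them.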

The DistributedSampler is used to split the dataset into chunks, so that each process would only load its subset. Once your custom sampler works fine, you could adapt the split logic from the DistributedSampler.
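A possible way to adapt that split logic to subbatches is sketched below. It is only an illustration, not the DistributedSampler implementation: shard_subbatches and its parameters are hypothetical names, and it assumes every process starts from the same [num_subbatches, subbatch_size] index tensor.

```python
import torch


def shard_subbatches(subbatches_idx, rank, world_size, epoch=0, seed=0, shuffle=True):
    """Return the rows of subbatches_idx that belong to one process.

    Assumption: every rank passes in the identical index tensor, so the
    seeded shuffle below produces the same order everywhere.
    """
    n = subbatches_idx.shape[0]
    if shuffle:
        g = torch.Generator()
        g.manual_seed(seed + epoch)  # same permutation on every rank
        subbatches_idx = subbatches_idx[torch.randperm(n, generator=g)]
    # Drop the tail so that every rank gets the same number of subbatches.
    usable = (n // world_size) * world_size
    return subbatches_idx[rank:usable:world_size]


# 10 subbatches of size 2, split across 3 processes.
subbatches = torch.arange(20).reshape(10, 2)
shards = [shard_subbatches(subbatches, rank, world_size=3) for rank in range(3)]
```

Dropping the tail is a design choice made here for simplicity; padding with repeated subbatches would be the alternative. Note that any shuffling done before this step would also need a shared seed, otherwise the ranks would disagree on what the subbatches are.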

Batch shapes will be the same across all batches during training.

Yeah, you made me realize that I actually don't need to get the data in the shape [batch_size, subbatch_size, channels, x_resolution, y_resolution]. [batch_size*subbatch_size, channels, x_resolution, y_resolution] would be fine, since my model will consume the batches like that anyway (it acts only on the tiles). Only my loss and maybe some other metrics depend on the subbatched predictions. So flattening and using the default collate_fn would be a good approach as well.
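To make the flattened variant concrete, here is a small sketch of how the per-tile predictions could be regrouped for a subbatch-level loss or metric. The sizes and the mean aggregation are placeholders chosen for illustration, and the regrouping only works if the tiles of each subbatch are contiguous in the flattened batch.

```python
import torch

batch_size, subbatch_size, num_classes = 4, 3, 2

# Stand-in for model output on a flattened batch of batch_size * subbatch_size tiles.
logits = torch.randn(batch_size * subbatch_size, num_classes)

# Restore the subbatch dimension: [batch_size, subbatch_size, num_classes].
per_subbatch = logits.view(batch_size, subbatch_size, num_classes)

# Hypothetical aggregation: average the tile predictions within each subbatch.
subbatch_logits = per_subbatch.mean(dim=1)
```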

After some testing, I made it work by making the __getitem__ method of my dataset accept the subbatched indices of my sampler and return subbatches. I guess I might sacrifice a tiny bit of performance this way when batch_size % num_workers != 0, as each worker will now collect subbatches instead of individual samples? But I assume it will be no real issue, since the DataLoader does prefetching anyway…
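For anyone reading along, the idea looks roughly like the sketch below. ToyTileDataset, _load_tile and the hard-coded index list are stand-ins invented for this example, not my actual dataset or sampler; the point is only that __getitem__ accepts a 1-D index tensor and that the default collate_fn then stacks the subbatches.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyTileDataset(Dataset):
    """Stand-in dataset: random tiles, label alternates between 0 and 1."""

    def __init__(self, num_tiles=12):
        self.tiles = torch.randn(num_tiles, 3, 8, 8)
        self.labels = torch.arange(num_tiles) % 2

    def __len__(self):
        return len(self.tiles)

    def _load_tile(self, i):
        return self.tiles[i], self.labels[i]

    def __getitem__(self, index):
        # A 1-D index tensor means "load a whole subbatch and stack it".
        if torch.is_tensor(index) and index.dim() == 1:
            tiles, labels = zip(*(self._load_tile(int(i)) for i in index))
            return torch.stack(tiles), torch.stack(labels)
        return self._load_tile(int(index))


# Stand-in for the sampler output: each entry holds the indices of one subbatch.
subbatches = [torch.tensor([0, 2, 4]), torch.tensor([1, 3, 5]),
              torch.tensor([6, 8, 10]), torch.tensor([7, 9, 11])]

loader = DataLoader(ToyTileDataset(), sampler=subbatches, batch_size=2)
tiles, labels = next(iter(loader))  # tiles: [batch_size, subbatch_size, C, H, W]
flat_tiles = tiles.flatten(0, 1)    # [batch_size * subbatch_size, C, H, W]
```

Here batch_size counts subbatches, not tiles, so each batch holds batch_size * subbatch_size tiles in total.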

This way I don’t need an additional custom BatchSampler, and I can also infer the subbatch_size directly from the data input to my model instead of passing it as an argument. Since my model needs to be able to accept any (or no) subbatch_size after training, that’s my preferred way.

So now all I need to do is distributify my sampler and I'm good to go!
