I have some questions regarding custom samplers.
First some background:
I have some large images, where each image is divided into many smaller tiles. The number of tiles varies between images. I already built a custom dataset from torchvision.datasets.ImageFolder, which records the class and the parent image (called parent slide in my use case) of each tile on initialization.
Now what I would like to do is to have an extra “subbatch” dimension in the data I get from the DataLoader, where all samples within that dimension are either:
- from the same class, or
- from the same slide
So a batch should have the shape [batch_size, subbatch_size, channels, x_resolution, y_resolution].
Thanks to the answer of the almighty @ptrblck in this thread:
I was already able to write a custom sampler, which should be able to sample the correct indices within such “subbatches”. I solved it a little differently from what's done in the thread, but it pointed me in the right direction. Here’s the code:
```python
import torch


class TileSubBatchSampler(torch.utils.data.Sampler):
    def __init__(self, subbatch_size, tile_image_dataset, mode='class', shuffle=True):
        self.subbatch_size = subbatch_size
        self.ds = tile_image_dataset
        self.mode = mode
        self.shuffle = shuffle
        # One index tensor per group (slide or class).
        if mode == 'slide':
            self.subs_idx = [torch.tensor(idx) for idx in self.ds.parent_slide.values()]
        elif mode == 'class':
            self.subs_idx = [torch.tensor(idx) for idx in self.ds.idx_in_class.values()]
        else:
            raise ValueError(f"mode must be 'slide' or 'class', got {mode!r}")

    def get_subbatched_sample_indexes(self):
        # Optionally shuffle the indices within each group.
        idx_per_sub = ([sub_idx[torch.randperm(len(sub_idx))] for sub_idx in self.subs_idx]
                       if self.shuffle else self.subs_idx)
        # Cut each group into full subbatches; incomplete remainders are dropped.
        # Groups with exactly subbatch_size entries are kept (>=), matching __len__.
        subbatches_idx = torch.cat(
            [torch.cat([sub_idx[i:i + self.subbatch_size][None, :]
                        for i in range(0, len(sub_idx), self.subbatch_size)
                        if len(sub_idx[i:i + self.subbatch_size]) == self.subbatch_size], dim=0)
             for sub_idx in idx_per_sub if len(sub_idx) >= self.subbatch_size], dim=0)
        # Optionally shuffle the order of the subbatches.
        if self.shuffle:
            subbatches_idx = subbatches_idx[torch.randperm(subbatches_idx.shape[0]), :]
        return subbatches_idx

    def __iter__(self):
        subbatches_idx = self.get_subbatched_sample_indexes()
        subbatches_idx = [subbatches_idx[n, :] for n in range(subbatches_idx.shape[0])]
        return iter(subbatches_idx)

    def __len__(self):
        return sum([len(sub_idx) // self.subbatch_size for sub_idx in self.subs_idx])
```
So on each iteration, it returns a torch.tensor with subbatch_size many indices, where all indices are from the same class or from the same parent slide. Just to clarify: ds.parent_slide is a dict “slide_number: [list of indices in slide]” and ds.idx_in_class is a dict “class: [list of indices in class]”.
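For illustration, here is a toy sketch of how two such dicts could look and be built. The `tiles` list and its labels are made up for this example; my actual dataset derives the same information from the ImageFolder paths.

```python
from collections import defaultdict

# Hypothetical toy metadata: one (class_label, slide_number) pair per tile index.
tiles = [("tumor", 0), ("tumor", 0), ("normal", 1), ("tumor", 2), ("normal", 1)]

parent_slide = defaultdict(list)   # slide_number -> tile indices in that slide
idx_in_class = defaultdict(list)   # class label  -> tile indices in that class
for idx, (label, slide) in enumerate(tiles):
    parent_slide[slide].append(idx)
    idx_in_class[label].append(idx)

print(dict(parent_slide))   # {0: [0, 1], 1: [2, 4], 2: [3]}
print(dict(idx_in_class))   # {'tumor': [0, 1, 3], 'normal': [2, 4]}
```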
Now here's my actual question:
What else do I need to do to make this work? Does this just work out of the box when I pass the sampler to the DataLoader? I couldn’t try it yet, but I guess not? Do I have to build a custom BatchSampler as well, where I flatten the batch and subbatch dimensions into one larger batch dimension and reshape my batches back into the wanted shape further down the line?
Or is that what the collate_fn, which is passed to the DataLoader, can be used for?
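To make the flatten-and-reshape idea concrete, here is a minimal sketch of what I have in mind. It is untested against my real dataset: `ToyTiles`, `FlatSubBatchSampler` and `make_collate` are made-up names for illustration, and the subbatches are hard-coded here instead of coming from the sampler above.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyTiles(Dataset):
    """Hypothetical stand-in for the tile dataset: 3x4x4 tensors filled with the index."""
    def __len__(self):
        return 24

    def __getitem__(self, idx):
        return torch.full((3, 4, 4), float(idx)), idx % 2


class FlatSubBatchSampler:
    """Groups `batch_size` subbatches and yields them as one flat list of indices."""
    def __init__(self, subbatches, batch_size):
        self.subbatches = subbatches   # sequence of index sequences, each of length subbatch_size
        self.batch_size = batch_size

    def __iter__(self):
        # Only full groups are yielded; a trailing incomplete group is dropped.
        for i in range(0, len(self.subbatches) - self.batch_size + 1, self.batch_size):
            group = self.subbatches[i:i + self.batch_size]
            yield [int(idx) for sub in group for idx in sub]

    def __len__(self):
        return len(self.subbatches) // self.batch_size


def make_collate(subbatch_size):
    def collate(samples):
        # samples is a flat list of (image, label) pairs in sampler order.
        images = torch.stack([img for img, _ in samples])
        labels = torch.tensor([lab for _, lab in samples])
        # Split the flat batch dimension back into [batch, subbatch, ...].
        return (images.view(-1, subbatch_size, *images.shape[1:]),
                labels.view(-1, subbatch_size))
    return collate


subbatch_size, batch_size = 4, 2
subbatches = [list(range(i, i + subbatch_size)) for i in range(0, 24, subbatch_size)]
loader = DataLoader(ToyTiles(),
                    batch_sampler=FlatSubBatchSampler(subbatches, batch_size),
                    collate_fn=make_collate(subbatch_size))
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([2, 4, 3, 4, 4])
```

With the real sampler, `subbatches` would presumably be something like `list(TileSubBatchSampler(...))`, rebuilt every epoch so the shuffling changes.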
Also, I would like to run this with DistributedDataParallel later, so I guess I have to turn it into a DistributedSampler as well? In that case, I can probably look here https://discuss.pytorch.org/t/how-to-use-my-own-sampler-when-i-already-use-distributedsampler/62143/8? Since my classes are imbalanced anyway, I could also take over the weighting part, which would obviously have to be adapted to weight the subbatches rather than the individual samples.
But if I have to build my own BatchSampler, which of the two samplers needs to be the DistributedSampler?
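A rough sketch of what I imagine for the distributed case, assuming the sharding happens at the subbatch level so that a subbatch is never split across processes. `shard_subbatches` is a hypothetical helper, not a PyTorch API; it only mimics the idea of seeding the shuffle with the epoch so that all ranks agree on the order.

```python
import random


def shard_subbatches(subbatches, rank, world_size, epoch=0, seed=0):
    """Return the subbatches that `rank` would see in `epoch` (hypothetical helper)."""
    order = list(range(len(subbatches)))
    # Same seed on every rank, so all ranks agree on the shuffled order.
    random.Random(seed + epoch).shuffle(order)
    # Drop the tail so that every rank gets the same number of subbatches.
    usable = len(order) - len(order) % world_size
    return [subbatches[i] for i in order[rank:usable:world_size]]


# Toy data: 10 subbatches of 4 indices each, split across 3 processes.
subbatches = [[4 * i + j for j in range(4)] for i in range(10)]
shards = [shard_subbatches(subbatches, r, world_size=3, epoch=1) for r in range(3)]
print([len(s) for s in shards])  # [3, 3, 3]
```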
It’s hard to derive this from the DataLoader code, since it is rather involved, so any tips would be much appreciated!