Efficient dataloader for pathology whole slide images

I’m working with pathology WSIs (very large, e.g., 100,000×100,000 px), split into tiles (224×224 RGB). Each WSI can have 500–15,000 tiles. Labels are only at the WSI level, so I need a DataLoader that:

  • Samples a subset of tiles per WSI (e.g., 1,000)

  • Applies augmentations to each tile individually. I’m using albumentations because it provides transforms that don’t exist in torchvision.

  • Returns a single tensor of all sampled tiles; the forward/backward pass happens once per WSI

The challenge is CPU utilization. A standard DataLoader parallelizes at the WSI level (num_workers > 0), but each WSI contains many tiles. Processing all tiles serially in __getitem__ underutilizes CPUs, while trying to parallelize tile augmentations inside __getitem__ risks nested parallelism, oversubscription, and deadlocks.
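
For completeness, the standard knobs for keeping each worker single-threaded look like this (they prevent oversubscription from the underlying libraries, but they do nothing to parallelize the serial tile loop itself):

# Usually exported before launching training, since BLAS reads it at import time:
#   OMP_NUM_THREADS=1 python train.py
import cv2
import torch

torch.set_num_threads(1)   # intra-op threads for torch ops in each worker
cv2.setNumThreads(0)       # stop OpenCV's internal threading used by albumentations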

This is my Dataset class:

import torch
from torch.utils.data import Dataset


class WSIDataset(Dataset):
    def __init__(self, wsi_list, transform=None, max_tiles=None):
        self.wsi_list = wsi_list           # list of WSIs
        self.transform = transform         # albumentations augmentation pipeline
        self.max_tiles = max_tiles         # cap on tiles per WSI

    def __len__(self):
        return len(self.wsi_list)

    def __getitem__(self, idx):
        # Load all tiles of one WSI as HxWxC numpy arrays
        tiles = self._load_tiles_for_wsi(self.wsi_list[idx])

        if self.max_tiles and len(tiles) > self.max_tiles:
            tiles = tiles[:self.max_tiles]

        # Apply augmentations serially, tile by tile -- this is the bottleneck.
        # The transform must return a tensor (e.g. end with ToTensorV2)
        # for torch.stack to work.
        augmented = torch.stack(
            [self.transform(image=tile)["image"] for tile in tiles]
        )
        return augmented

Do you have any suggestions or best practices for parallelizing work inside a PyTorch DataLoader?

I don’t think PyTorch provides anything to parallelize intra-batch processing out of the box, but I may be missing something.

Did you actually run into deadlocks or reduced performance with nested parallelism (parallelizing tile loading, tile augmentation, or both, inside __getitem__ using e.g. multiprocessing yourself)?
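
For what it’s worth, threads are usually the safer route there: a ThreadPoolExecutor inside __getitem__ avoids the fork-related issues of nested multiprocessing, and the OpenCV/NumPy operations underlying albumentations release the GIL for most of their work, so the threads do run in parallel. A minimal sketch against your class, assuming your transform is an albumentations pipeline ending in ToTensorV2 and reusing your _load_tiles_for_wsi helper (tile_threads is a new parameter I’m introducing here):

import cv2
import torch
from concurrent.futures import ThreadPoolExecutor
from torch.utils.data import Dataset

cv2.setNumThreads(0)  # keep OpenCV from spawning extra threads in every worker

class WSIDataset(Dataset):
    def __init__(self, wsi_list, transform=None, max_tiles=None, tile_threads=4):
        self.wsi_list = wsi_list
        self.transform = transform
        self.max_tiles = max_tiles
        self.tile_threads = tile_threads
        self._pool = None  # created lazily, one pool per DataLoader worker

    def __len__(self):
        return len(self.wsi_list)

    def _augment(self, tile):
        return self.transform(image=tile)["image"]

    def __getitem__(self, idx):
        tiles = self._load_tiles_for_wsi(self.wsi_list[idx])
        if self.max_tiles and len(tiles) > self.max_tiles:
            tiles = tiles[:self.max_tiles]
        if self._pool is None:
            # Lazy init so the pool is built inside the worker process,
            # after the DataLoader has forked, never before.
            self._pool = ThreadPoolExecutor(self.tile_threads)
        return torch.stack(list(self._pool.map(self._augment, tiles)))

Keep num_workers × tile_threads at or below your physical core count, otherwise you just trade deadlocks for oversubscription.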

If that seems reasonable, you could also pre-apply the transforms offline: augment every tile ahead of time, save the results to disk, and then, when training, directly load tiles that are already transformed (see the sketch below).
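
Since the augmentations are random, in practice you would save a fixed number of sampled variants per tile rather than literally every possible one. A rough one-off script — load_tiles, out_dir, and N_VARIANTS are all placeholders for however your tiles are actually stored:

import numpy as np
from pathlib import Path

N_VARIANTS = 5                       # augmented copies kept per tile
out_dir = Path("augmented_tiles")    # placeholder output location
out_dir.mkdir(exist_ok=True)

for wsi_id, wsi in enumerate(wsi_list):
    # load_tiles is a stand-in for your own tile loader
    for tile_id, tile in enumerate(load_tiles(wsi)):
        for k in range(N_VARIANTS):
            # assuming the transform returns a numpy array here
            # (i.e. the pipeline without ToTensorV2)
            aug = transform(image=tile)["image"]
            np.save(out_dir / f"{wsi_id}_{tile_id}_{k}.npy", aug)

At training time __getitem__ then reduces to np.load plus torch.from_numpy, which is cheap enough that the workers become pure I/O. The trade-off is disk space (N_VARIANTS times your dataset) and a fixed augmentation pool instead of fresh randomness every epoch.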

Also, you could slightly optimize your __getitem__ by applying the max_tiles cap before loading the tiles rather than after, e.g. by pushing it down into _load_tiles_for_wsi.

def __getitem__(self, idx):
    # Pass the cap down to the loader so the excess tiles are never
    # read from disk (assuming _load_tiles_for_wsi can accept one)
    tiles = self._load_tiles_for_wsi(
        self.wsi_list[idx], max_tiles=self.max_tiles
    )
    ...

This avoids reading tiles from disk only to discard them whenever a WSI has more than max_tiles tiles.

I hope this helps!


A bit late to the discussion, but I happen to have built myself a little library for this exact purpose last year:

Feel free to use it or just explore the implementation; it’s pretty simple, but it worked well for my needs!