I’m working with pathology WSIs (very large, e.g., 100,000×100,000 px), split into tiles (224×224 RGB). Each WSI can have 500–15,000 tiles. Labels are only at the WSI level, so I need a DataLoader that:
- Samples a subset of tiles per WSI (e.g., 1,000)
- Applies augmentations to each tile individually. I am using albumentations because it has specific transforms that don't exist in torchvision.
- Returns a single tensor of all tiles; forward/backprop happens per WSI
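To make the first requirement concrete, here is a minimal sketch of per-WSI tile subsampling (`sample_tiles` is a name I made up, and numpy arrays stand in for the tile stacks; tiles are shrunk to 8×8 just to keep the demo small):

```python
import numpy as np

def sample_tiles(tiles, max_tiles, rng):
    """Randomly sample up to max_tiles tiles; keep all if fewer exist."""
    if len(tiles) <= max_tiles:
        return tiles
    idx = rng.choice(len(tiles), size=max_tiles, replace=False)
    return tiles[idx]

rng = np.random.default_rng(0)
# 5,000 toy "tiles" (8x8 here instead of 224x224 to keep the demo small)
tiles = rng.integers(0, 256, size=(5000, 8, 8, 3), dtype=np.uint8)
subset = sample_tiles(tiles, 1000, rng)
print(subset.shape)  # (1000, 8, 8, 3)
```

Sampling without replacement keeps the subset unbiased; a WSI with fewer than `max_tiles` tiles is passed through untouched.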
The challenge is CPU utilization. A standard DataLoader parallelizes at the WSI level (num_workers > 0), but each WSI contains many tiles. Processing all tiles serially in __getitem__ underutilizes CPUs, while trying to parallelize tile augmentations inside __getitem__ risks nested parallelism, oversubscription, and deadlocks.
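One guard I have seen suggested (a sketch, not something I have validated at scale): if tile augmentations are parallelized inside `__getitem__` at all, a thread pool is safer than a nested process pool, since numpy/OpenCV-backed transforms release the GIL for much of their work and threads avoid fork-inside-worker deadlocks:

```python
from concurrent.futures import ThreadPoolExecutor

def augment_tiles(tiles, transform, max_workers=4):
    # Threads, not processes: nested process pools inside DataLoader
    # workers are a classic source of deadlocks and oversubscription.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(transform, tiles))

# Toy transform standing in for an albumentations pipeline
doubled = augment_tiles([1, 2, 3, 4], lambda t: t * 2)
print(doubled)  # [2, 4, 6, 8]
```

Even so, `max_workers * num_workers` threads compete for the same cores, so the two knobs have to be tuned together.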
This is my Dataset class:
import torch
from torch.utils.data import Dataset

class WSIDataset(Dataset):
    def __init__(self, wsi_list, transform=None, max_tiles=None):
        self.wsi_list = wsi_list    # list of WSI identifiers/paths
        self.transform = transform  # albumentations pipeline (ends in ToTensorV2)
        self.max_tiles = max_tiles

    def __len__(self):
        return len(self.wsi_list)

    def __getitem__(self, idx):
        tiles = self._load_tiles_for_wsi(self.wsi_list[idx])
        if self.max_tiles and len(tiles) > self.max_tiles:
            tiles = tiles[:self.max_tiles]
        # Apply augmentations tile by tile (albumentations takes keyword
        # input and returns a dict)
        augmented = torch.stack(
            [self.transform(image=tile)["image"] for tile in tiles]
        )
        return augmented
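For reference, this is roughly how I consume such a dataset (`TinyWSIDataset` below is a self-contained stub I wrote for illustration, not my real loader): `batch_size=None` hands each WSI's tile stack through unbatched, and capping `torch.set_num_threads(1)` per worker is one standard guard against CPU oversubscription.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class TinyWSIDataset(Dataset):
    """Stub standing in for WSIDataset: each item is one WSI's tile stack."""
    def __init__(self, n_wsis=4, n_tiles=16):
        self.n_wsis, self.n_tiles = n_wsis, n_tiles

    def __len__(self):
        return self.n_wsis

    def __getitem__(self, idx):
        # (N_tiles, C, H, W); tiny 8x8 tiles keep the demo cheap
        return torch.zeros(self.n_tiles, 3, 8, 8)

def worker_init_fn(worker_id):
    # Cap intra-op threads in each worker so num_workers processes
    # don't oversubscribe the CPU between them.
    torch.set_num_threads(1)

loader = DataLoader(TinyWSIDataset(), batch_size=None, num_workers=2,
                    worker_init_fn=worker_init_fn)
for wsi_tiles in loader:
    print(wsi_tiles.shape)  # torch.Size([16, 3, 8, 8])
```

With `batch_size=None` the DataLoader skips collation entirely, so no artificial batch dimension of 1 needs to be squeezed out later.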
Do you have any suggestions or best practices for parallelizing this kind of per-item work inside a PyTorch DataLoader?