Loading 3D data

Hello,

I have a dataset composed of 3D volumes, small enough to be loaded in memory.

I need to process these volumes one batch of slices at a time because of GPU memory restrictions.

What would be a clean way to go about this?
My current solution is subclassing torch.utils.data.Dataset to return one volume at a time, then getting slices of the returned volume manually. But this just feels like a somewhat hacky way to create batches…

Code snippet for the Dataset:

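```python
from torch.utils.data import Dataset


class My_Dataset(Dataset):
    """Volume-level dataset: each item is one full (image, mask) volume held in memory."""

    def __init__(self, paths):
        """
        Args:
            paths: list of paths to the volumes
        """
        self.paths = paths
        self.load_everything_in_memory()

    def __len__(self):
        # One item per volume
        return len(self.images)

    def __getitem__(self, idx):
        image, mask = self.images[idx], self.masks[idx]
        return image, mask

    def load_volume(self, path):
        # Hypothetical placeholder: replace with your own loading code
        # that returns (image, mask) for the volume stored at `path`
        raise NotImplementedError

    def load_everything_in_memory(self):
        """
        Sets:
            self.images: list of volumes
            self.masks: labels for those volumes
        """
        images, masks = [], []
        for path in self.paths:
            # do some loading: get image, mask
            image, mask = self.load_volume(path)
            images.append(image)
            masks.append(mask)

        self.images = images
        self.masks = masks
```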
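For reference, the manual slicing described above can be written as a small generator over any volume-level dataset (including My_Dataset). This is a minimal sketch assuming the slice axis is dim 0 of both the image and mask tensors; `iter_slice_batches` is a hypothetical helper name, and the synthetic volumes in the usage lines are made up.

```python
import torch


def iter_slice_batches(volume_dataset, batch_size):
    """Hypothetical helper: yield (image_batch, mask_batch) chunks of consecutive slices.

    volume_dataset: anything with len() and integer indexing that returns
    (image, mask) tensors, e.g. a volume-level Dataset. Slices are assumed
    to lie along dim 0.
    """
    for idx in range(len(volume_dataset)):
        image, mask = volume_dataset[idx]
        # torch.split cuts along dim 0; the last chunk of a volume may be smaller
        for img_chunk, mask_chunk in zip(torch.split(image, batch_size, dim=0),
                                         torch.split(mask, batch_size, dim=0)):
            yield img_chunk, mask_chunk


# Usage with two synthetic volumes of 10 and 7 slices (made-up shapes)
volumes = [
    (torch.randn(10, 1, 64, 64), torch.zeros(10, 64, 64, dtype=torch.long)),
    (torch.randn(7, 1, 64, 64), torch.zeros(7, 64, 64, dtype=torch.long)),
]
batches = list(iter_slice_batches(volumes, batch_size=4))
```

Batches never cross volume boundaries here, so the last batch of each volume can be smaller than batch_size.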

It feels like there should be a better way of doing this.

Thanks in advance

This sounds like a valid approach.
Alternatively, you could of course store all slices in a single tensor or list, use the index to grab one slice, and build the desired batches from those slices, but I guess load_everything_in_memory might then become a bit more complicated.
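A minimal sketch of that alternative, assuming every volume is already an in-memory tensor with slices along dim 0 and that all slices share the same in-plane size (the default collate function needs that to stack them). SliceDataset is a hypothetical name; instead of concatenating the volumes, it keeps a flat index of (volume, slice) pairs, so loading can stay per volume while DataLoader takes care of batching and shuffling.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SliceDataset(Dataset):
    """Hypothetical slice-level dataset: each index maps to one 2D slice of one volume."""

    def __init__(self, images, masks):
        # images/masks: lists of in-memory tensors shaped (num_slices, ...),
        # with the slice axis assumed to be dim 0
        self.images = images
        self.masks = masks
        # Flat index -> (volume index, slice index), built once up front
        self.index = [(v, s) for v, vol in enumerate(images) for s in range(vol.shape[0])]

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        v, s = self.index[idx]
        return self.images[v][s], self.masks[v][s]


# Usage with two synthetic volumes of 10 and 7 slices (made-up shapes)
images = [torch.randn(10, 1, 64, 64), torch.randn(7, 1, 64, 64)]
masks = [torch.zeros(10, 64, 64, dtype=torch.long), torch.zeros(7, 64, 64, dtype=torch.long)]
dataset = SliceDataset(images, masks)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
```

With shuffle=True a batch can mix slices from different volumes; use shuffle=False if slices should stay in their original order.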

You could try to use torch.utils.checkpoint to trade compute for memory, but I’m not sure how large your memory requirement is for a single batch.
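To illustrate the checkpointing idea, here is a minimal sketch on a toy 2D network (the architecture and channel counts are made up). `torch.utils.checkpoint.checkpoint` does not keep the wrapped block's intermediate activations during the forward pass and recomputes them during backward instead; the `use_reentrant=False` argument assumes a reasonably recent PyTorch version.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedNet(nn.Module):
    """Toy 2D network (hypothetical) whose middle block is recomputed during backward."""

    def __init__(self, channels=16):
        super().__init__()
        self.stem = nn.Conv2d(1, channels, kernel_size=3, padding=1)
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1), nn.ReLU(),
        )
        self.head = nn.Conv2d(channels, 2, kernel_size=1)

    def forward(self, x):
        x = self.stem(x)
        # Activations inside self.block are not stored; they are recomputed in backward
        x = checkpoint(self.block, x, use_reentrant=False)
        return self.head(x)


# Usage on a made-up batch of 4 single-channel 64x64 slices
model = CheckpointedNet()
x = torch.randn(4, 1, 64, 64)
out = model(x)
out.mean().backward()
```

This saves activation memory only for the checkpointed part, at the cost of running that part's forward pass twice.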


Thanks for your reply! I suppose it's good enough for now as it is; I can always look to optimise later if I really need to.