Hello,
I have a dataset composed of 3D volumes, small enough to be loaded in memory.
I need to process these volumes one batch of slices at a time (due to GPU memory restrictions).
What would be a clean way to go about this?
My current solution is to subclass torch.utils.data.Dataset so it returns one volume at a time, and then to take slices of the returned volume manually (I've sketched that slicing loop below the Dataset code). But this just feels like a somewhat hacky way to create batches…
Code snippet for the Dataset:
from torch.utils.data import Dataset


class My_Dataset(Dataset):
    def __init__(self, paths):
        """
        Args:
            paths: list of paths to the volumes
        """
        self.paths = paths
        self.load_everything_in_memory()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image, mask = self.images[idx], self.masks[idx]
        return image, mask

    def load_everything_in_memory(self):
        """
        self.images = list of volumes
        self.masks = labels for those volumes
        """
        images, masks = [], []
        for path in self.paths:
            # do some loading
            # get image, mask
            images.append(image)
            masks.append(mask)
        self.images = images
        self.masks = masks
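And roughly what I mean by slicing the returned volume manually (the slice batch size, the slice axis, and the model call are just placeholders for whatever your setup uses):

from torch.utils.data import DataLoader

dataset = My_Dataset(paths)
loader = DataLoader(dataset, batch_size=1)  # one whole volume per iteration

slice_batch = 8  # placeholder: however many slices fit in GPU memory

for volume, mask in loader:
    # assuming volume has shape (1, D, H, W), with D the slice axis
    for start in range(0, volume.shape[1], slice_batch):
        vol_slices = volume[:, start:start + slice_batch]
        mask_slices = mask[:, start:start + slice_batch]
        # ... forward pass / loss on vol_slices, mask_slices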
It feels like there should be a better way of doing this.
Thanks in advance