Dataset: multiple patches from the same CT + shuffle


I have 3D CT images and their segmentation masks. From each CT I would like to extract a variable number of 3D patches, each centered on an ROI defined by one of the labels in the segmentation mask.

I would like to train on these patches in random order, regardless of which CT they came from (i.e. shuffle=True in the DataLoader).

For this purpose, the Dataset's __getitem__ should return a single patch + patch_mask pair.

I can do this by loading the same full-size CT image every time I want to crop a patch from it, like so:

```python
import SimpleITK as sitk
from torch.utils.data import Dataset

class myDataset(Dataset):

    def __init__(self, df, size):
        # df is a DataFrame with the paths to the full-size CTs and masks, like so:
        #  index |    ID        |     ID_mask       |   label
        #    0    'path_to_CT1'   'path_to_CT1_mask'      1
        #    1    'path_to_CT1'   'path_to_CT1_mask'      2
        #    2    'path_to_CT2'   'path_to_CT2_mask'      1
        #    3    'path_to_CT2'   'path_to_CT2_mask'      2
        #    4    'path_to_CT2'   'path_to_CT2_mask'      3
        #    5    'path_to_CT2'   'path_to_CT2_mask'      4
        #    6    'path_to_CT3'   'path_to_CT3_mask'      1
        #   ...       ....              ....            ....
        self.df = df
        self.size = size  # size of the patches

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]  # one row = one (CT, label) pair = one patch
        path_to_CT = row['ID']
        path_to_CT_mask = row['ID_mask']
        label = row['label']
        CT = sitk.ReadImage(path_to_CT)  # loads the entire CT (512x512x512)
        CT_mask = sitk.ReadImage(path_to_CT_mask)  # loads the entire CT mask (512x512x512)

        # crop a single patch + patch mask around the label (ROI), e.g. 128x128x128
        # (crop_roi is a cropping helper, not shown here)
        patch, patch_mask = crop_roi(CT, CT_mask, label, self.size)

        return {'imgs': patch, 'masks': patch_mask}
```
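For reference, crop_roi could look roughly like the minimal sketch below. This is a hypothetical version, not necessarily what the original helper does: it assumes the CT and mask have already been converted to NumPy arrays of the same shape (for example with sitk.GetArrayFromImage), that every axis is at least size voxels long, and it centers the patch on the centroid of the voxels carrying label, shifting the window so it stays inside the volume.

```python
import numpy as np

def crop_roi(ct, ct_mask, label, size):
    """Hypothetical helper: crop a size^3 patch from ct and ct_mask,
    centered on the centroid of the voxels in ct_mask equal to label.

    Assumes ct and ct_mask are NumPy arrays of identical shape whose
    axes are all at least `size` voxels long.
    """
    coords = np.argwhere(ct_mask == label)  # (n_voxels, 3) indices of the ROI
    if coords.size == 0:
        raise ValueError(f"label {label} not found in mask")
    center = np.round(coords.mean(axis=0)).astype(int)  # ROI centroid
    shape = np.array(ct.shape)
    # Shift the window (without resizing it) so it lies fully inside the volume.
    start = np.clip(center - size // 2, 0, shape - size)
    window = tuple(slice(s, s + size) for s in start)
    return ct[window], ct_mask[window]
```

Clipping the start index keeps every patch the same size even when the ROI lies near the border; padding the volume instead would be an equally valid choice.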

Alternatively, I could extract all patches of a CT at once, cache them (for example in a dictionary built in __init__), and serve them from subsequent __getitem__ calls. However, this either forces me to iterate over the patches in order (no shuffling), or to keep the patches of all CTs in memory at once, which requires far too much memory.
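To put rough numbers on the memory concern, here is a back-of-the-envelope calculation using the sizes mentioned above (512³ volumes, 128³ patches, up to 150 patches per CT). The dtypes are assumptions: float32 images and uint8 masks (CTs are often stored as int16, which would halve the image figures).

```python
MiB = 2**20

def volume_bytes(side, bytes_per_voxel):
    """Bytes needed for a cubic volume of side**3 voxels."""
    return side**3 * bytes_per_voxel

# Assumed dtypes (adjust to your data): float32 images, uint8 masks.
IMG_BYTES, MASK_BYTES = 4, 1

patch_pair = volume_bytes(128, IMG_BYTES) + volume_bytes(128, MASK_BYTES)
ct_pair = volume_bytes(512, IMG_BYTES) + volume_bytes(512, MASK_BYTES)
cached_patches_one_ct = 150 * patch_pair  # all patches of one CT

print(patch_pair / MiB)             # 10.0
print(ct_pair / MiB)                # 640.0
print(cached_patches_one_ct / MiB)  # 1500.0
```

Under these assumptions, caching all patches of a single CT (1500 MiB) already takes more memory than the full CT and its mask together (640 MiB): 150 × 128³ voxels is about 2.3 × 512³, so the patches necessarily overlap and the same voxels are stored several times.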

I’ve read in a previous post that I can “stack the crops into the batch dimension”. I’ve tried it, but because the number of patches varies from image to image, the per-CT stacks have different shapes and I don’t think the default batching can handle them.
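For illustration, "stacking the crops into the batch dimension" with a variable number of patches per CT could look like the sketch below. It assumes a hypothetical dataset variant whose __getitem__ returns all patches of one CT as arrays of shape (n_i, D, H, W). PyTorch's default collation stacks items along a new first dimension, which requires every item to have the same shape and therefore fails when n_i differs; a custom collate_fn that concatenates along the existing first dimension has no such restriction.

```python
import numpy as np

def collate_patches(batch, rng=None):
    """Hypothetical collate_fn: merge a variable number of patches per CT.

    Each element of `batch` is assumed to be a dict
    {'imgs': (n_i, D, H, W) array, 'masks': (n_i, D, H, W) array},
    where n_i may differ between elements.
    """
    # Concatenate along the existing patch axis, so differing n_i are fine.
    imgs = np.concatenate([item['imgs'] for item in batch], axis=0)
    masks = np.concatenate([item['masks'] for item in batch], axis=0)
    # Shuffle patches within the batch, applying the same permutation to
    # images and masks so each pair stays aligned.
    rng = np.random.default_rng() if rng is None else rng
    perm = rng.permutation(len(imgs))
    return {'imgs': imgs[perm], 'masks': masks[perm]}
```

It would be passed as DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_patches), converting with torch.from_numpy if tensors are needed. Two caveats of this design: shuffle=True only shuffles the order of the CTs, so patches are mixed only among the CTs in the same batch, and the effective batch size becomes the total number of patches, which can reach several hundred with up to 150 patches per CT.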

My question: is there a way to avoid reloading the same CT every time I need a patch from it, given that a single CT can yield up to 150 patches?