Custom Dataset Memory Issues

I have a custom dataset that I use to load multi-modal .npz files and create samples for my model. This is what I currently do:

def __getitem__(self, idx):
    source = self.data[idx]
    while True:
        # Retry up to 10 times to draw a target sufficiently different from the source.
        for _ in range(10):
            target = select_target(self.data)
            if diff(source, target) < 0.10:
                continue

            sample = {}
            if self.use_m1:
                sample['m1'] = torch.stack([
                    to_tensor(normalize_images(
                        np.transpose(cv2.resize(read_npz(m1_input), dsize=(256, 256),
                                                interpolation=cv2.INTER_AREA)[..., :3], (2, 0, 1)),
                        max=255))
                    for m1_input in m1_dirs
                ], dim=0)
            if self.use_m2:
                # Nearest-neighbour resize so the semantic labels are not interpolated.
                sample['m2'] = torch.stack([
                    to_tensor(normalize_images(
                        np.expand_dims(cv2.resize(read_npz(m2_input), dsize=(256, 256),
                                                  interpolation=cv2.INTER_NEAREST), axis=0),
                        max=40))
                    for m2_input in m2_dirs  # assumption: m2 files are read with read_npz, like m1
                ], dim=0)

            if self.config["transform"] and random.random() < 0.5 and self.train:
                sample = self.transformations(sample)
            else:
                if self.use_m1:
                    sample['m1'] = torch.cat(list(sample['m1']), dim=0)
                if self.use_m2:
                    sample['m2'] = torch.cat(list(sample['m2']), dim=0)

            return sample
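
A note on read_npz, since its definition isn't shown: np.load on an .npz file returns a lazy NpzFile that keeps the file handle and its buffers alive until it is closed. Below is a minimal sketch of a loader that closes it eagerly; it assumes read_npz is a thin wrapper around np.load, and 'image' is a placeholder for the real array key.

import numpy as np

def read_npz(path, key='image'):
    # Hypothetical stand-in for the real read_npz; 'key' is a placeholder.
    # Using np.load as a context manager closes the NpzFile, and .copy()
    # detaches the returned array from the file's buffers.
    with np.load(path) as npz:
        return npz[key].copy()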

I apply transformations as follows:

def transformations(self, sample):
    # Draw one set of crop parameters and reuse it across modalities so they stay aligned.
    # 'target_material' is one of the other modality keys (not shown here).
    i, j, h, w = self.crop.get_params(sample['target_material'][0], scale=(0.7, 1.0), ratio=(1.0, 1.0))
    if self.use_m1:
        sample['m1'] = torch.cat([self.color_jitter(TF.resized_crop(t, i, j, h, w, size=(256, 256)))
                                  for t in sample['m1']], dim=0)
    if self.use_m2:
        sample['m2'] = torch.cat([TF.resized_crop(t, i, j, h, w, size=(self.img_size, self.img_size),
                                                  interpolation=TF.InterpolationMode.NEAREST)
                                  for t in sample['m2']], dim=0)
    return sample

I have a list of directories that I read from in __getitem__. I have multiple modalities to load (6, to be exact), and each modality has 4 images associated with it. I keep getting OOM errors around epochs 5-6 and cannot figure out where the issue is. I don’t get a CUDA OOM; it just seems like system memory becomes the issue with this custom dataset.
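
For reference, one minimal way to watch the growth from outside the training loop (a sketch; assumes psutil is installed and dataset is the instance above):

import os
import psutil

proc = psutil.Process(os.getpid())
for i in range(len(dataset)):
    _ = dataset[i]  # repeated __getitem__ calls
    if i % 100 == 0:
        # Resident set size in MB; with this dataset it climbs steadily.
        print(i, proc.memory_info().rss / 1024 ** 2)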

Any help would be appreciated.

Edit: I’ve done some memory tracing and figured out that the issue is indeed in the dataloader. Memory keeps increasing with every __getitem__() call.
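
A minimal tracemalloc sketch of the kind of check that shows this (dataset construction elided; the iteration count is arbitrary):

import tracemalloc

tracemalloc.start(25)  # keep 25 frames so tracebacks reach into __getitem__
before = tracemalloc.take_snapshot()
for i in range(200):
    _ = dataset[i]
after = tracemalloc.take_snapshot()
# Largest allocation growth between the two snapshots, grouped by source line.
for stat in after.compare_to(before, 'lineno')[:10]:
    print(stat)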