I have a custom dataset that loads multi-modal npz files and creates samples for my model. This is how I currently do it:

```
def __getitem__(self, idx):
    source = self.data[idx]
    # keep sampling until a target sufficiently different from the source is found
    while True:
        for _ in range(10):
            target = select_target(self.data)
            if diff(source, target) < 0.10:
                continue

            sample = {}
            if self.use_m1:
                sample['m1'] = torch.stack([
                    to_tensor(normalize_images(
                        np.transpose(
                            cv2.resize(read_npz(m1_input), dsize=(256, 256),
                                       interpolation=cv2.INTER_AREA)[..., :3],
                            (2, 0, 1)),
                        max=255))
                    for m1_input in m1_dirs
                ], dim=0)
            if self.use_m2:
                sample['m2'] = torch.stack([
                    to_tensor(normalize_images(
                        np.expand_dims(
                            cv2.resize(read_npz(m2_input), dsize=(256, 256),
                                       interpolation=cv2.INTER_NEAREST),
                            axis=0),
                        max=40))
                    for m2_input in m2_dirs
                ], dim=0)

            if self.config["transform"] and random.random() < 0.5 and self.train:
                sample = self.transformations(sample)
            else:
                if self.use_m1:
                    sample['m1'] = torch.cat([m1_tensor for m1_tensor in sample['m1']], dim=0)
                if self.use_m2:
                    sample['m2'] = torch.cat([m2_tensor for m2_tensor in sample['m2']], dim=0)
            return sample
```

I apply transformations as follows:

```
def transformations(self, input):
    # sample one crop and apply it to every modality so they stay spatially aligned
    i, j, h, w = self.crop.get_params(input['target_material'][0],
                                      scale=(0.7, 1.0), ratio=(1.0, 1.0))
    if self.use_modality1:
        input["m1"] = torch.cat([
            self.color_jitter(TF.resized_crop(sample, i, j, h, w, size=(256, 256)))
            for sample in input['m1']
        ], dim=0)
    if self.use_semantic:
        input["m2"] = torch.cat([
            TF.resized_crop(sample, i, j, h, w,
                            size=(self.img_size, self.img_size),
                            interpolation=TF.InterpolationMode.NEAREST)
            for sample in input['m2']
        ], dim=0)
    return input
```

I have a list of directories that I read from in `__getitem__`. I have multiple modalities to load (6 to be exact), and each modality has 4 images associated with it. I keep getting OOM errors around epochs 5-6 and I cannot figure out where the issue is. I don't get CUDA OOM; it seems like host (CPU) memory becomes the issue with this custom dataset.
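For scale, here's a back-of-the-envelope estimate of one sample's footprint, assuming every modality ends up as 4 float32 images of 3x256x256 (my actual channel counts vary per modality, so treat this as a rough lower bound):

```python
# Hypothetical per-sample memory estimate; shapes/dtypes are assumptions,
# not my exact pipeline.
n_modalities = 6
imgs_per_modality = 4
c, h, w = 3, 256, 256
bytes_per_float32 = 4

per_image = c * h * w * bytes_per_float32              # 786,432 bytes per image
per_sample = n_modalities * imgs_per_modality * per_image
print(f"~{per_sample / 2**20:.1f} MiB per sample")     # ~18.0 MiB
```

So a single sample is on the order of tens of MiB before any augmentation copies, which is why even a small amount of unintended retention per call adds up fast across epochs.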

Any help would be appreciated.

Edit: I've done some memory tracing and confirmed the issue is indeed in the dataloader. Memory keeps increasing with every `__getitem__()` call.
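For anyone who wants to reproduce the measurement: this is roughly how I traced it, using the stdlib `tracemalloc` around repeated `__getitem__` calls. The dataset below is a hypothetical stand-in that deliberately retains every sample (a common leak pattern), just to show what growth looks like:

```python
import tracemalloc

class DummyDataset:
    """Hypothetical stand-in for my real dataset.

    It appends each sample to a list, so memory grows on every call --
    the same symptom I see, regardless of the actual cause."""
    def __init__(self):
        self.cache = []

    def __getitem__(self, idx):
        sample = bytearray(1 << 20)   # ~1 MiB per sample
        self.cache.append(sample)     # unintended retention
        return sample

tracemalloc.start()
ds = DummyDataset()
before, _ = tracemalloc.get_traced_memory()
for i in range(10):
    ds[i]
after, _ = tracemalloc.get_traced_memory()
tracemalloc.stop()
print(f"memory grew by ~{(after - before) / 2**20:.1f} MiB over 10 calls")
```

With a healthy `__getitem__`, the traced total should stay roughly flat across calls; in my case it climbs the way this dummy does.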