Return both original and augmented images

Hello,

Since my custom dataset returns a single sample (one dict) per index, I can’t return both the original and the augmented image from one call. Here is a code example:

def __getitem__(self, idx):
    """Return one sample, paired with its augmented copy if augmentation changed it.

    Looks up the sample id ``self.ids[idx]``. It globs the mask in
    ``self.masks_dir`` (stem = id + ``self.mask_suffix``) and the image in
    ``self.images_dir`` (stem = id), loads both with ``self.load``, and
    preprocesses them with ``self.preprocess``. If ``self.transform`` is set,
    it is applied to the ``(img, mask)`` pair.

    Assumes ``self.load``/``self.preprocess``/``self.transform`` produce numpy
    arrays (``.copy()`` is called on them) -- TODO confirm.

    Returns:
        dict with keys ``'image'`` (float tensor), ``'mask'`` (long tensor) and
        ``'filename'`` (the mask file's name). If there is no transform, or the
        transform left the image unchanged, only the original image/mask are
        returned. Otherwise original and augmented tensors are concatenated
        along dim 0.

    Raises:
        FileNotFoundError: if no mask or no image file matches the sample id.
    """
    import numpy as np

    name = self.ids[idx]
    mask_pattern = name + self.mask_suffix + '.*'
    img_pattern = name + '.*'
    mask_matches = list(self.masks_dir.glob(mask_pattern))
    img_matches = list(self.images_dir.glob(img_pattern))

    # Do not let a missing file surface as IndexError: that exception is also
    # how plain ``for sample in dataset`` iteration detects the end of the
    # data, so it would silently truncate the dataset instead of failing.
    if not mask_matches:
        raise FileNotFoundError(
            f"no mask matching {mask_pattern!r} in {self.masks_dir}"
        )
    if not img_matches:
        raise FileNotFoundError(
            f"no image matching {img_pattern!r} in {self.images_dir}"
        )
    mask_file = mask_matches[0]
    img_file = img_matches[0]

    mask = self.load(mask_file)
    img = self.load(img_file)

    img = self.preprocess(img, self.scale, is_mask=False)
    mask = self.preprocess(mask, self.scale, is_mask=True)

    # Without a transform there is no augmented copy, so treat the sample as
    # unchanged. (Previously img_aug/mask_aug were left unbound here.)
    if self.transform is None:
        img_aug, mask_aug = img, mask
    else:
        img_aug, mask_aug = self.transform((img, mask))

    # Path.name is separator-agnostic, unlike splitting str(path) on "/".
    filename = mask_file.name

    def _image_tensor(arr):
        # copy() first: torch.as_tensor cannot wrap numpy arrays with negative
        # strides (e.g. np.flip output), which augmentations may produce.
        return torch.as_tensor(arr.copy()).float().contiguous()

    def _mask_tensor(arr):
        return torch.as_tensor(arr.copy()).long().contiguous()

    # np.array_equal checks shape and values in native code; the previous
    # flatten().tolist() comparison built two Python lists per call.
    if np.array_equal(img, img_aug):
        return {
            'image': _image_tensor(img),
            'mask': _mask_tensor(mask),
            'filename': filename,
        }

    # NOTE(review): if images are (C, H, W), concatenating on dim 0 yields
    # (2C, H, W). Because unchanged samples keep the original size, a batch
    # mixing both kinds cannot be stacked by DataLoader's default collate_fn.
    # Confirm the intended layout (e.g. always returning a pair, or a custom
    # collate_fn that flattens pairs into the batch).
    return {
        'image': torch.cat([_image_tensor(img), _image_tensor(img_aug)], dim=0),
        'mask': torch.cat([_mask_tensor(mask), _mask_tensor(mask_aug)], dim=0),
        'filename': filename,
    }
     
# Placeholder: substitute the real Dataset instance here.
dataset = ...
# Shuffled loader yielding batches of 16 samples.
train_dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
)

Is there an easy way to combine both images?

I’m not sure I understand the issue correctly: are you seeing an error when running your code? If not, I assume you would like to change it somehow, so could you explain what the desired outputs should look like?