Confusion about PyTorch Data Augmentation

I’m currently using a UNet for image regression and want to explore data augmentation. However, it appears to me that PyTorch transformations (as implemented below) replace the original images in the dataset. Since the total amount of data in the dataset stays the same, won’t the model’s performance be similar to training without augmentation? I’m not sure that assumption is correct, though. Also, are the transforms refreshed every epoch?

from torchvision import transforms

rotation_angle = 90  # RandomRotation samples an angle uniformly from [-90, 90]
rotation_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomRotation(degrees=rotation_angle),
    transforms.ToTensor()
])


dataset = FullDataset(dir_input, dir_mask, transform=rotation_transform)

FullDataset __getitem__() snippet:

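aug_input, aug_mask = input, mask  # fall back to the raw data if no transform is set
if self.transform:
    # each call draws its own random angle, so input and mask are rotated independently
    aug_input = self.transform(input)
    aug_mask = self.transform(mask)
return aug_input, aug_mask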

The transformation does not replace the data; it is applied on the fly each time a sample is loaded, so the stored images stay unchanged.
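To see this, here is a minimal sketch using a hypothetical stand-in for FullDataset called ToyAugDataset. To keep it dependent on torch alone, it swaps RandomRotation for a random intensity scaling (random_scale, also hypothetical), but the mechanism is the same: the transform runs inside __getitem__, the stored tensor is never modified, and requesting the same index twice (as happens once per epoch) gives a freshly augmented sample.

```python
import torch
from torch.utils.data import Dataset


class ToyAugDataset(Dataset):
    """Hypothetical stand-in for FullDataset: stores fixed tensors and
    applies a random transform only when an item is requested."""

    def __init__(self, data: torch.Tensor, transform=None):
        self.data = data  # stored samples; never modified
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)  # new random draw on every access
        return sample


def random_scale(x: torch.Tensor) -> torch.Tensor:
    # Multiply by a fresh random factor in [0.5, 1.5); returns a new tensor,
    # so the stored sample is not changed in place.
    factor = torch.empty(1).uniform_(0.5, 1.5)
    return x * factor


data = torch.ones(4, 1, 8, 8)
ds = ToyAugDataset(data, transform=random_scale)
first, second = ds[0], ds[0]  # e.g. sample 0 in epoch 1 and in epoch 2
print(torch.equal(data, torch.ones(4, 1, 8, 8)))  # True: storage untouched
```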

No, since each sample is transformed randomly every time it is loaded, i.e. in every iteration, so the transforms are effectively refreshed every epoch. Depending on the transformations used, the probability of repeating exactly the same random configuration (for all applied transforms) can approach zero.
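The last point can be illustrated with a small simulation. This sketch does not call torchvision; it only mimics the parameter draws with Python's random module, under two assumptions: a discrete augmentation (a random horizontal plus vertical flip, giving 2 x 2 = 4 possible configurations) versus a continuous one (an angle drawn uniformly from [-90, 90], as with RandomRotation(degrees=90)). The function name distinct_configs is hypothetical.

```python
import random


def distinct_configs(num_epochs: int, continuous: bool, seed: int = 0) -> int:
    """Count how many distinct augmentation configs one sample receives
    over num_epochs, drawing one fresh config per epoch."""
    rng = random.Random(seed)
    seen = set()
    for _ in range(num_epochs):
        if continuous:
            # mimics RandomRotation(degrees=90): angle ~ U(-90, 90)
            config = rng.uniform(-90.0, 90.0)
        else:
            # mimics a random horizontal + vertical flip: 4 possible configs
            config = (rng.random() < 0.5, rng.random() < 0.5)
        seen.add(config)
    return len(seen)


print(distinct_configs(100, continuous=False))  # at most 4
print(distinct_configs(100, continuous=True))
```

With only flips, a sample must repeat configurations after a few epochs, while a continuous angle practically never repeats, so every epoch shows the model a new variant of each image.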