Albumentations in PyTorch: inconsistent augmentation for multi-target datasets

I'm using PyTorch and want to augment my images with Albumentations. My dataset object has two targets, 'blurry' and 'sharp', and each blurry/sharp pair must receive exactly the same augmentations. When I apply the augmentation inside a Dataset object like this:

from torch.utils.data import Dataset

class ApplyTransform(Dataset):
    def __init__(self, dataset, transformation):
        self.dataset = dataset
        self.aug = transformation

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]['blurry'], self.dataset[idx]['sharp']
        transformedImgs = self.aug(image=sample, target_image=target)
        sample_aug, target_aug = transformedImgs["image"], transformedImgs["target_image"]
        return {'blurry': sample_aug, 'sharp': target_aug}

Unfortunately, the two returned images have received different augmentations.

When I call the same pipeline directly, without a Dataset object, both images receive identical augmentations. Does anybody know how to make this work inside a Dataset object?

Here is my augmentation pipeline:

import albumentations as A
from albumentations.pytorch import ToTensorV2

augmentation_transform = A.Compose(
    [
        A.Resize(1024, 1024, p=1),
        A.HorizontalFlip(p=0.25),
        A.Rotate(limit=(-45, 65)),
        A.VerticalFlip(p=0.24),
        A.RandomContrast(limit=0.3, p=0.15),
        A.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ToTensorV2(always_apply=True, p=1.0),
    ],
    # apply every "image" transform to "target_image" with the same parameters
    additional_targets={"target_image": "image"},
)
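One quick way to confirm the symptom, sketched here under the assumption that the pipeline is called with the keyword arguments `image=` and `target_image=` as above: pass the same image as both targets and compare the outputs. Every step in this pipeline is deterministic once its random parameters are fixed, so identical inputs should give identical outputs whenever both targets share those parameters. `augmentations_match` is a hypothetical helper, not part of Albumentations.

```python
import numpy as np


def augmentations_match(transform, image, n_trials=20):
    """Hypothetical diagnostic helper (not part of Albumentations).

    Feeds the SAME image as both targets. If the targets share their random
    parameters, every output pair must be identical; a mismatch in any trial
    means the two targets were augmented differently.
    """
    for _ in range(n_trials):
        out = transform(image=image, target_image=image.copy())
        # np.asarray also accepts CPU torch tensors, e.g. ToTensorV2 output
        first = np.asarray(out["image"])
        second = np.asarray(out["target_image"])
        if not np.array_equal(first, second):
            return False
    return True
```

For example, `augmentations_match(augmentation_transform, img)` with an H x W x 3 `uint8` array `img` checks the direct call; the Dataset path can be checked the same way by wrapping a dataset whose 'blurry' and 'sharp' entries are the same image.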

I think a simple approach here, albeit more tedious, is to write your own transformation function using the functional versions of the transforms: draw each random decision once and apply it to both images, e.g.:

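import torch
import albumentations

def buildmytransform(hflip_prob, vflip_prob, ...):
    def _transform(blurry, sharp):
        # Start from the inputs so both variables exist even if a flip is skipped
        blur_tmp, sharp_tmp = blurry, sharp
        # Draw the random decision once and apply it to BOTH images
        if torch.rand(1) < hflip_prob:
            blur_tmp = albumentations.augmentations.functional.hflip(blur_tmp)
            sharp_tmp = albumentations.augmentations.functional.hflip(sharp_tmp)
        # Pass the same rotation parameters to both calls
        blur_tmp = albumentations.augmentations.functional.rotate(blur_tmp, ...)
        sharp_tmp = albumentations.augmentations.functional.rotate(sharp_tmp, ...)
        if torch.rand(1) < vflip_prob:
            ...
        return blur_tmp, sharp_tmp
    return _transform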
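A runnable variant of the same idea, as a minimal sketch: it assumes H x W x C NumPy images and swaps the arbitrary-angle `rotate` for 90-degree turns (`np.rot90`) so that it needs only NumPy and the standard library; it also uses Python's `random` instead of `torch.rand`. `build_paired_transform` and `PairedDataset` are hypothetical names, and in a real project `PairedDataset` would subclass `torch.utils.data.Dataset`.

```python
import random

import numpy as np


def build_paired_transform(hflip_prob=0.5, vflip_prob=0.5, rot90_prob=0.5):
    """Hypothetical helper: returns a function that applies the SAME random
    flips and 90-degree rotation to a blurry/sharp pair of H x W x C arrays."""

    def _transform(blurry, sharp):
        # Draw every random decision exactly once per call ...
        do_hflip = random.random() < hflip_prob
        do_vflip = random.random() < vflip_prob
        k = random.randint(1, 3) if random.random() < rot90_prob else 0  # 90-degree turns

        def apply(img):
            # ... and apply the identical decisions to each image.
            if do_hflip:
                img = np.fliplr(img)  # flip along the width axis
            if do_vflip:
                img = np.flipud(img)  # flip along the height axis
            if k:
                img = np.rot90(img, k, axes=(0, 1))
            # Copy to a contiguous array (flips produce negative strides)
            return np.ascontiguousarray(img)

        return apply(blurry), apply(sharp)

    return _transform


class PairedDataset:
    """Hypothetical wrapper; in practice subclass torch.utils.data.Dataset."""

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # One call transforms both images with shared random decisions
        blurry, sharp = self.transform(item["blurry"], item["sharp"])
        return {"blurry": blurry, "sharp": sharp}
```

The random decisions use the module-level `random` generator rather than a private generator object, because PyTorch reseeds Python's `random` in each `DataLoader` worker process; a generator created before the workers are forked would give every worker the same sequence.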