How do I apply same transformations to image and mask?

While implementing data augmentation for semantic segmentation training, I tried to apply the same transformations to an image and its mask: if I rotate the image, I need to rotate the mask by the same angle. The problem is that RandomRotation, RandomHorizontalFlip, etc. draw their parameters randomly.

I read somewhere that these seeds are generated when the transforms are instantiated. I'm trying to figure out how to ensure that the image and the mask are transformed in exactly the same way. Here is a "simple" example of the workflow I've tried, which results in an image and a mask with different rotations:

```python
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision.io import read_image
from torchvision.transforms import v2

# Load mask
mask_path = os.path.join("data/train_masks", "0cdf5b5d0ce1_01_mask.gif")
mask = torch.tensor(np.array(Image.open(mask_path), dtype=np.float32), dtype=torch.float32).unsqueeze(0)
mask[mask == 255.0] = 1.0

# Load image
img_path = os.path.join("data/train", "0cdf5b5d0ce1_01.jpg")
image = read_image(img_path)

# Define transforms
transform = v2.Compose([
    v2.Resize((160, 240)),
    v2.RandomRotation(degrees=35)
])

# Apply transforms
image, mask = transform(image, mask)

# Image
plt.subplot(1, 2, 1)
plt.imshow(image.permute(1, 2, 0))
plt.title('Image')
plt.axis('off')

# Mask
plt.subplot(1, 2, 2)
plt.imshow(mask.permute(1, 2, 0), cmap='gray')
plt.title('Mask')
plt.axis('off')

# Show figure
plt.show()
```

Also, if anyone has a better way of reading a GIF file, I'm all ears 🙂.
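One option is to let Pillow normalize the GIF to 8-bit grayscale and binarize it in one step, instead of building a float array and patching the 255 values afterwards. A minimal sketch, assuming a single-frame mask where the foreground is any non-zero (white) pixel; the helper name `load_binary_mask` is hypothetical:

```python
import numpy as np
import torch
from PIL import Image


def load_binary_mask(path: str) -> torch.Tensor:
    """Hypothetical helper: load a single-frame GIF/PNG mask as a (1, H, W) float32 tensor of 0s and 1s."""
    with Image.open(path) as img:
        # convert("L") maps palette ("P") or other modes to 8-bit grayscale values
        gray = np.array(img.convert("L"))
    # Assumption: any non-zero pixel is foreground; adjust if your masks encode classes differently
    mask = torch.from_numpy(gray > 0).to(torch.float32)
    return mask.unsqueeze(0)  # add the channel dimension -> (1, H, W)
```

Going through `convert("L")` means the same code works whether Pillow opens the GIF in palette mode or grayscale mode.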

You can either use the functional API as described here, or torchvision.transforms.v2, which allows you to pass multiple objects as described here.


Whoa, #til

We have always used albumentations for augmentation here. I didn't know that torchvision had its own augmentation system.

It should work, so could you add debug print statements to the transformation to show which angle is used?
E.g., something like:

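```python
import random
import torchvision.transforms.functional as F

# Inside your custom transform's __call__(self, input, target):
angle = random.uniform(-self.degrees, self.degrees)
print("using angle {} to rotate input".format(angle))
input = F.rotate(input, angle)
print("using angle {} to rotate target".format(angle))
target = F.rotate(target, angle)
```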

It should print the same angle twice, so both images would be rotated by it.
If that's the case, your plot might show the images in the wrong order, or you might be shuffling the images somewhere afterwards.

So you have confirmed that the same angle is indeed used for both rotations and that the plotting is not the issue. In that case, the correspondence between input and target seems to be broken somewhere else, and we would need a minimal, executable code snippet to reproduce and isolate the issue.