While adding data augmentation to a semantic segmentation training pipeline, I need to apply the same transformations to an image and its mask: if I rotate the image, I need to rotate the mask by the same angle. The problem is that RandomRotation, RandomHorizontalFlip, etc. draw their parameters at random on every call.
I read somewhere that these seeds are generated when the transforms are instantiated. I’m trying to figure out how to ensure that both the image and the mask are transformed in the same way. Here is a “simple” example of the workflow I’ve tried, which produces an image and a mask with different rotations.
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision.io import read_image
from torchvision.transforms import v2
# Load mask
mask_path = os.path.join("data/train_masks", "0cdf5b5d0ce1_01_mask.gif")
mask = torch.from_numpy(np.array(Image.open(mask_path), dtype=np.float32)).unsqueeze(0)
mask[mask == 255.0] = 1.0
# Load image
img_path = os.path.join("data/train", "0cdf5b5d0ce1_01.jpg")
image = read_image(img_path)
# Define transforms
transform = v2.Compose([
    v2.Resize((160, 240)),
    v2.RandomRotation(degrees=35),
])
# Apply transforms
image, mask = transform(image, mask)
# Image
plt.subplot(1, 2, 1)
plt.imshow(image.permute(1, 2, 0))
plt.title('Image')
plt.axis('off')
# Mask
plt.subplot(1, 2, 2)
plt.imshow(mask.squeeze(0), cmap='gray')  # imshow expects (H, W) for grayscale, not (H, W, 1)
plt.title('Mask')
plt.axis('off')
# Show figure
plt.show()
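One workaround I’m considering (just a minimal sketch using the classic functional API; the -35/35 range mirrors degrees=35 above, and I haven’t confirmed this is the intended v2 way) is to sample the rotation angle once myself and apply it to both tensors:

import random
import torchvision.transforms.functional as TF

# Resize both deterministically (nearest-neighbor for the mask so the labels stay binary)
image = TF.resize(image, [160, 240])
mask = TF.resize(mask, [160, 240], interpolation=TF.InterpolationMode.NEAREST)

# Sample one angle and reuse it for both tensors
angle = random.uniform(-35.0, 35.0)
image = TF.rotate(image, angle, interpolation=TF.InterpolationMode.BILINEAR)
mask = TF.rotate(mask, angle)  # rotate defaults to nearest-neighbor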
Also, if anyone has a better way of reading a GIF file, I’m all ears.
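For reference, the tidiest variant I’ve found so far keeps PIL but goes through pil_to_tensor (the convert("L") is my assumption about the GIF’s palette, and I’m not sure this is actually better):

import torchvision.transforms.functional as TF

# Decode the GIF with PIL, force 8-bit grayscale, and get a (1, H, W) float tensor
mask = TF.pil_to_tensor(Image.open(mask_path).convert("L")).float()
mask[mask == 255.0] = 1.0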