How to rotate a 2D image using TF.rotate

Hi, I was trying to write an augmentation pipeline for images and masks in a segmentation task and I got stuck with TF.rotate. In particular, I get an error when rotating the 2D mask. Here's a reproduction of the error:

import torch
import torchvision.transforms.functional as TF

shape_image = (3,512,512)
shape_mask = (512, 512)
image = torch.rand(shape_image)
mask = torch.rand(shape_mask)

TF.rotate(image, 15) # runs fine
TF.rotate(mask, 15)  # raises the RuntimeError below

RuntimeError: grid_sampler(): expected 4D or 5D input and grid with same number of dimensions, but got input with sizes [1, 512, 512] and grid with sizes [1, 512, 512, 2]

I should mention that I'm new to PyTorch, so I may just be missing something. Any help would be much appreciated.

Thanks in advance :slight_smile:

This might be because rotate expects tensor inputs to have at least 3 dimensions ([C, H, W]) and the mask only has two ([H, W]). Can you try mask = torch.unsqueeze(mask, 0) to add a dummy "channel" dimension to the mask before rotating it? You can drop that dimension again afterwards with squeeze(0).

Thanks, that worked! :smiley: