How to perform torchvision rotation and flip transformations for a semantic segmentation task?

Hi, I am working on a semantic segmentation task on a custom dataset, and I want to augment the data using transformations such as flipping, rotating, cropping, and resizing.

My input image is an RGB image of shape (3, h, w), and my labels are a target and masks of shape (h, w) and (n, h, w) respectively, where h is the height and w the width of the image, and n is the number of classes in the segmentation task.

A snippet from my current code is as follows:

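import random

import torchvision.transforms as transforms

# Inside the dataset's __getitem__: the same transform pipeline is applied
# to the image and to both label arrays.
seed = random.randint(0, 2**32)  # drawn here but not used in this snippet
im = self.transforms(im)
target = self.transforms(target)
mask = self.transforms(mask)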

and I get the following error:

Traceback (most recent call last):
  File "", line 227, in <module>
  File "", line 143, in train
    for i, dataBatch in enumerate(segTrainLoader):
  File "/home/cv_dev/.local/lib/python3.6/site-packages/torch/utils/data/", line 345, in __next__
    data = self._next_data()
  File "/home/cv_dev/.local/lib/python3.6/site-packages/torch/utils/data/", line 385, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/cv_dev/.local/lib/python3.6/site-packages/torch/utils/data/_utils/", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/cv_dev/.local/lib/python3.6/site-packages/torch/utils/data/_utils/", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/cv_dev/pose_estimation/rebarnet_vectors/", line 75, in __getitem__
    target = self.transforms(target)
  File "/home/cv_dev/.local/lib/python3.6/site-packages/torchvision/transforms/", line 61, in __call__
    img = t(img)
  File "/home/cv_dev/.local/lib/python3.6/site-packages/torchvision/transforms/", line 501, in __call__
    return F.hflip(img)
  File "/home/cv_dev/.local/lib/python3.6/site-packages/torchvision/transforms/", line 414, in hflip
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>

I understand that torchvision transforms such as rotation work on PIL Images. I converted the input image to a PIL Image and it worked. However, I do not know how to convert the masks into PIL Images, given that their dimensions do not match those of a grayscale, RGB, or RGBA image.

Can someone suggest a way to solve this?

Did you check this?

The reply assumes that the mask is an RGB or grayscale image and can be loaded as a PIL Image. In the case of semantic segmentation, that is not true: the mask can have an arbitrary number of channels.
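One possible workaround for the arbitrary-channel case is to treat each (h, w) channel as its own grayscale PIL Image and apply the same rotation and flip to every channel. The sketch below is only an illustration under some assumptions: all arrays are uint8 NumPy arrays, the helper names (`_transform_plane`, `augment`) and the `max_angle` parameter are made up for this example, and it calls Pillow directly rather than going through `self.transforms`. The random parameters are drawn once so that the image, target, and mask stay aligned.

```python
import random

import numpy as np
from PIL import Image


def _transform_plane(plane, angle, flip, resample):
    """Rotate and optionally flip a single (h, w) uint8 array via PIL."""
    pil = Image.fromarray(plane)  # a 2-D uint8 array becomes a mode "L" image
    pil = pil.rotate(angle, resample=resample)  # output keeps the input size
    if flip:
        pil = pil.transpose(Image.FLIP_LEFT_RIGHT)
    return np.array(pil)


def augment(im, target, mask, max_angle=30.0):
    """Apply one random rotation/flip to im (3,h,w), target (h,w), mask (n,h,w)."""
    # Draw the random parameters once and reuse them for every array.
    angle = random.uniform(-max_angle, max_angle)
    flip = random.random() < 0.5

    # Bilinear resampling for the image; nearest for labels, so that class
    # values are never blended into values that are not valid labels.
    im_out = np.stack(
        [_transform_plane(c, angle, flip, Image.BILINEAR) for c in im]
    )
    target_out = _transform_plane(target, angle, flip, Image.NEAREST)
    mask_out = np.stack(
        [_transform_plane(c, angle, flip, Image.NEAREST) for c in mask]
    )
    return im_out, target_out, mask_out
```

Note that pixels rotated in from outside the frame are filled with 0 here, so if 0 is a real class in the target you may want a separate ignore value for those regions.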

@SatyamGaba I find albumentations faster and easier for such tasks.