Hi, I am working on a semantic segmentation task on a custom dataset, and I want to augment the data using transformations such as flipping, rotating, cropping, and resizing.
My input image is an RGB image of shape (3, h, w), and my labels are a target and masks of shape (h, w) and (n, h, w) respectively, where h is the height and w the width of the image, and n is the number of classes in the segmentation task.
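To make the shapes concrete, here is a dummy setup matching the description above (the sizes h, w, n are made up for illustration, not from my real dataset):

```python
import numpy as np

h, w, n = 64, 96, 5  # example sizes for illustration only

im = np.random.rand(3, h, w).astype(np.float32)            # RGB image, (3, h, w)
target = np.random.randint(0, n, (h, w)).astype(np.uint8)  # class-index map, (h, w)

# one binary plane per class, (n, h, w)
mask = np.zeros((n, h, w), dtype=np.uint8)
for c in range(n):
    mask[c] = (target == c)

print(im.shape, target.shape, mask.shape)
```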
A snippet from my current code is as follows:
import random

import torchvision.transforms as transforms
seed = random.randint(0, 2**32)
self._set_seed(seed)
im = self.transforms(im)
self._set_seed(seed)
target = self.transforms(target)
self._set_seed(seed)
mask = self.transforms(mask)
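For context, _set_seed is just a helper that reseeds the RNG the transforms draw from, so that consecutive calls pick the same random parameters. A minimal sketch of what such a helper might look like (the exact body in my code may differ):

```python
import random


def _set_seed(seed):
    # torchvision's random transforms (in older versions) draw their
    # parameters from Python's `random` module, so reseeding it makes
    # consecutive transform calls pick the same parameters.
    random.seed(seed)


seed = random.randint(0, 2**32)
_set_seed(seed)
a = random.random()  # parameter drawn for the image
_set_seed(seed)
b = random.random()  # parameter drawn for the target
print(a == b)        # both calls saw the same random draw
```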
and I get the following error:
Traceback (most recent call last):
  File "train.py", line 227, in <module>
    train()
  File "train.py", line 143, in train
    for i, dataBatch in enumerate(segTrainLoader):
  File "/home/cv_dev/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/home/cv_dev/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 385, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/cv_dev/.local/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/cv_dev/.local/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/cv_dev/pose_estimation/rebarnet_vectors/dataLoader.py", line 75, in __getitem__
    target = self.transforms(target)
  File "/home/cv_dev/.local/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
    img = t(img)
  File "/home/cv_dev/.local/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 501, in __call__
    return F.hflip(img)
  File "/home/cv_dev/.local/lib/python3.6/site-packages/torchvision/transforms/functional.py", line 414, in hflip
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>
I have understood that torchvision transforms like rotation work on PIL Images. I converted the input image to a PIL Image and it worked. However, I do not know how to convert the masks into PIL Images, given that their dimensions are not those of a grayscale, RGB, or RGBA image.
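For reference, the conversion that worked for the input image looks roughly like this (a sketch with a dummy array; my real code loads the image from disk):

```python
import numpy as np
from PIL import Image

im = (np.random.rand(3, 64, 96) * 255).astype(np.uint8)  # dummy (3, h, w) RGB array

# PIL expects channels last, so move the channel axis before converting
pil_im = Image.fromarray(im.transpose(1, 2, 0))
print(pil_im.mode, pil_im.size)  # note: PIL size is (width, height)
```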
Can someone suggest a way to solve this?