Hi, I’m trying to use the new torchvision.transforms.v2 API to apply transforms to both my images and their masks in a semantic segmentation task.
My images have a shape of (1, H, W) and the ground-truth masks have a shape of (3, H, W) (they are one-hot encoded), as I am trying to classify each pixel into 3 categories, with the last channel corresponding to the background.
I apply a RandomRotation() to the data, and I want the fill value to be 1.0 for the image and [0, 0, 1] across the three channels of the mask, i.e. a background pixel. I tried doing this via the dictionary method described in the documentation, assuming that if a tuple can be specified for datapoints.Image, it would also work for masks.
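For concreteness, a single sample shaped like my data looks roughly like this (the sizes and tensors below are just placeholders):

import torch

H, W = 64, 64                 # placeholder sizes
image = torch.rand(1, H, W)   # grayscale image, shape (1, H, W)
mask = torch.zeros(3, H, W)   # one-hot mask, shape (3, H, W)
mask[2] = 1.0                 # every pixel starts as background, i.e. [0, 0, 1]

After a rotation, the newly exposed border should therefore be 1.0 in the image and [0, 0, 1] in the mask. This is what I tried: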
import torchvision.transforms.v2
from torchvision import datapoints

fill_values = {
    datapoints.Image: 1.0,
    datapoints.Mask: (0.0, 0.0, 1.0),
}
transforms = torchvision.transforms.v2.Compose((
    torchvision.transforms.v2.RandomRotation(180, fill=fill_values),
))
ds = MyDataset("mydatafile", transforms=transforms)
image, target = ds[1]
This gives me an error:
Traceback (most recent call last):
  File "/Data/cnnurop/transform_test.py", line 79, in <module>
    main()
  File "/Data/cnnurop/transform_test.py", line 68, in main
    image, target = ds[1]
                    ~~^^^
  File "/Data/cnnurop/transform_test.py", line 39, in __getitem__
    img, target = self.transform(img, target)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Data/cnnurop/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Data/cnnurop/venv/lib/python3.11/site-packages/torchvision/transforms/v2/_container.py", line 51, in forward
    sample = transform(sample)
             ^^^^^^^^^^^^^^^^^
  File "/Data/cnnurop/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Data/cnnurop/venv/lib/python3.11/site-packages/torchvision/transforms/v2/_transform.py", line 44, in forward
    flat_outputs = [
                   ^
  File "/Data/cnnurop/venv/lib/python3.11/site-packages/torchvision/transforms/v2/_transform.py", line 45, in <listcomp>
    self._transform(inpt, params) if needs_transform else inpt
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Data/cnnurop/venv/lib/python3.11/site-packages/torchvision/transforms/v2/_geometry.py", line 643, in _transform
    fill = self._fill[type(inpt)]
           ~~~~~~~~~~^^^^^^^^^^^^
KeyError: <class 'torch.Tensor'>
It’s as if the Mask gets unwrapped into a plain tensor when the transforms are applied.
I also tried setting a regular scalar value instead of a tuple as the Mask fill value, which did not help.
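That attempt looked roughly like this (the exact scalar is just illustrative):

fill_values = {
    datapoints.Image: 1.0,
    datapoints.Mask: 0.0,  # plain scalar instead of a tuple; this did not help either
}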
Currently, I instead just set fill=1.0, which sets all channels of the mask tensor to 1 in the filled regions, and in __getitem__(self, idx) I manually zero out those values in every channel except the background one. This works, but it is a cumbersome workaround. Is there a better way of doing this?
Here is the custom dataset class I use for reference:
import torch
from pathlib import Path
from torchvision import datapoints


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, folder: Path, transforms=None):
        # Custom file loading here
        self.images = ...  # Loading from a file
        self.masks = ...   # Loading from a file
        assert len(self.images) == len(self.masks)
        self.length = len(self.images)
        self.transforms = transforms

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        img = self.images[idx]
        masks = self.masks[idx]
        target = {
            "masks": datapoints.Mask(masks)
        }
        if self.transforms is not None:
            img, target = self.transforms(img, target)
            m = target["masks"]
            # Workaround: fill=1.0 sets every channel of the mask to 1 in the
            # rotated-out regions, so zero those pixels in all channels except
            # the background channel.
            if len(m.shape) == 4:  # (B, 3, H, W) shape
                mask = torch.all(m == 1, dim=1).unsqueeze(dim=1).repeat(1, 3, 1, 1)
                mask[:, -1, :, :] = False
                m[mask] = 0
            elif len(m.shape) == 3:  # (3, H, W) shape
                all_channels_one = torch.all(m == 1, dim=0).unsqueeze(dim=0).repeat(2, 1, 1)
                m[:-1][all_channels_one] = 0
            target["masks"] = m
        return img, target