Mask-specific fill value in transforms gives KeyError

Hi, I’m trying to use the new torchvision.transforms v2 API to apply transforms to both my images and their masks in a semantic segmentation task.

My images have a shape of (1, H, W) and the ground truth masks have a shape of (3, H, W) (they are one-hot encoded), since I am classifying each pixel into 3 categories, with the last one corresponding to the background.
I apply a RandomRotation() to the data, and I want the fill value to be 1.0 for the image and [0, 0, 1] across the three channels for the mask, corresponding to a background pixel. I tried doing this via the dictionary syntax described in the documentation, assuming that if a tuple can be specified for datapoints.Image, it would work for masks as well.

import torchvision
import torchvision.transforms.v2
from torchvision import datapoints

fill_values = {
    datapoints.Image: 1.0,
    datapoints.Mask: (0.0, 0.0, 1.0),
}
transforms = torchvision.transforms.v2.Compose((
    torchvision.transforms.v2.RandomRotation(180, fill=fill_values),
))
ds = MyDataset("mydatafile", transforms=transforms)
image, target = ds[1]

This gives me an error:

Traceback (most recent call last):
  File "/Data/cnnurop/transform_test.py", line 79, in <module>
    main()
  File "/Data/cnnurop/transform_test.py", line 68, in main
    image, target = ds[1]
                    ~~^^^
  File "/Data/cnnurop/transform_test.py", line 39, in __getitem__
    img, target = self.transform(img, target)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Data/cnnurop/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Data/cnnurop/venv/lib/python3.11/site-packages/torchvision/transforms/v2/_container.py", line 51, in forward
    sample = transform(sample)
             ^^^^^^^^^^^^^^^^^
  File "/Data/cnnurop/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Data/cnnurop/venv/lib/python3.11/site-packages/torchvision/transforms/v2/_transform.py", line 44, in forward
    flat_outputs = [
                   ^
  File "/Data/cnnurop/venv/lib/python3.11/site-packages/torchvision/transforms/v2/_transform.py", line 45, in <listcomp>
    self._transform(inpt, params) if needs_transform else inpt
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Data/cnnurop/venv/lib/python3.11/site-packages/torchvision/transforms/v2/_geometry.py", line 643, in _transform
    fill = self._fill[type(inpt)]
           ~~~~~~~~~~^^^^^^^^^^^^
KeyError: <class 'torch.Tensor'>

It’s as if the Mask gets unwrapped into a tensor when the transforms get applied.
I also tried setting a plain scalar rather than a tuple as the Mask fill value, which did not help.
For now I instead set fill=1.0, which sets all channels of the mask tensor to 1, and in __getitem__(self, idx) I manually zero out these values in every channel except the background one (see the dataset class below). This works, but it is a cumbersome workaround. Is there a better way of doing this?

Here is the custom dataset class I use for reference:

import torch
from pathlib import Path
from torchvision import datapoints


class MyDataset(torch.utils.data.Dataset):

    def __init__(self, folder: Path, transforms=None):
        # Custom file loading here
        self.images = ...   # Loading from a file
        self.masks = ...    # Loading from a file

        assert len(self.images) == len(self.masks)
        self.length = len(self.images)
        self.transforms = transforms

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        img = self.images[idx]
        masks = self.masks[idx]
        target = {
            "masks": datapoints.Mask(masks)
        }
        if self.transforms is not None:
            img, target = self.transforms(img, target)
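        # Workaround: fill=1.0 sets every mask channel to 1 in the area exposed
        # by the rotation, so zero those pixels in all channels except the
        # background channel.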
        m = target['masks']
        if len(m.shape) == 4:  # (B, 3, H, W) shape
            mask = torch.all(m == 1, dim=1).unsqueeze(dim=1).repeat(1, 3, 1, 1)
            mask[:, -1, :, :] = False
            m[mask] = 0
        elif len(m.shape) == 3:  # (3, H, W) shape
            all_channels_one = torch.all(m == 1, dim=0).unsqueeze(dim=0).repeat(2, 1, 1)
            m[:-1][all_channels_one] = 0
        target['masks'] = m
        return img, target


The issue was that the image was not wrapped in datapoints.Image.
Adding img = datapoints.Image(img) to the Dataset’s __getitem__ method solved the KeyError.
However, it seems that tuple fill values for masks still don’t work.
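
For completeness, here is roughly what the fixed __getitem__ looks like (a sketch based on the class above; the mask clean-up from the workaround is omitted):

    def __getitem__(self, idx):
        # Wrapping the image in datapoints.Image (in addition to wrapping the
        # mask in datapoints.Mask) lets the v2 transforms look up the fill
        # value by type instead of falling back to plain torch.Tensor.
        img = datapoints.Image(self.images[idx])
        target = {"masks": datapoints.Mask(self.masks[idx])}
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target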