@moreshud, sorry for the late reply.
> At first, is there an optimal approach of creating labelC using a PyTorch built-in function rather than my naive approach of using the bitwise operator?
You can do it with a one-line expression like this:
```python
import torch

labelA = torch.randint(low=0, high=2, size=(32, 32), dtype=torch.long)  # !!! dtype is long here
labelB = torch.randint(low=0, high=2, size=(32, 32), dtype=torch.long)
# 1 where both labelA and labelB are 0, otherwise 0
labelC = torch.clamp(1 - labelA - labelB, 0, 1)
```
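Assuming labelA and labelB only contain 0 and 1, the clamp expression is a logical NOR, so it should give the same mask as combining `torch.logical_or` and `torch.logical_not`. A small sketch to check that equivalence (the `labelC_clamp` / `labelC_logical` names are just for illustration):

```python
import torch

labelA = torch.randint(low=0, high=2, size=(32, 32), dtype=torch.long)
labelB = torch.randint(low=0, high=2, size=(32, 32), dtype=torch.long)

# Arithmetic version: 1 - A - B is 1, 0 or -1, and clamping maps it to 1 only where A == B == 0
labelC_clamp = torch.clamp(1 - labelA - labelB, 0, 1)

# Logical NOR: logical_or/logical_not return bool tensors, so cast back to long
labelC_logical = torch.logical_not(torch.logical_or(labelA, labelB)).long()

print(torch.equal(labelC_clamp, labelC_logical))  # True
```

Both versions keep everything on the tensor side; the clamp one has the advantage of producing a long tensor directly.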
> Regarding the change in dtype after transformation, I am not quite sure if this is based on my implementation or it's a bug per se.
Here is a working example with torchvision 0.8.1, where the labels keep their int64 dtype after resizing:
```python
import torch
from torchvision import transforms

image = torch.rand(1, 32, 32)
labelA = torch.randint(low=0, high=2, size=(32, 32), dtype=torch.long)
labelB = torch.randint(low=0, high=2, size=(32, 32), dtype=torch.long)
labelC = torch.clamp(1 - labelA - labelB, 0, 1)
labels = torch.stack([labelC, labelA, labelB])
print(labels.shape)

# Image: interpolation=2 is the PIL code for bilinear
image_trans = transforms.Compose([
    transforms.Resize((20, 20), interpolation=2),
    transforms.Normalize([0.5], [0.5]),
])
# Labels: interpolation=0 is the PIL code for nearest neighbour, so values stay 0/1
gt_trans = transforms.Compose([
    transforms.Resize((20, 20), interpolation=0),
])

transformed_image = image_trans(image)
transformed_labels = gt_trans(labels)
print(transformed_image.shape, transformed_labels.shape)
print(transformed_image.dtype, transformed_labels.dtype)
```
Output:
```
torch.Size([3, 32, 32])
torch.Size([1, 20, 20]) torch.Size([3, 20, 20])
torch.float32 torch.int64
```
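The labels are resized with nearest-neighbour interpolation because it only copies existing pixel values, whereas bilinear resizing blends neighbouring pixels and can produce fractional values that are no longer valid labels. A minimal sketch illustrating this with `torch.nn.functional.interpolate` directly (not the torchvision transform itself, and using random labels as an assumption):

```python
import torch
import torch.nn.functional as F

# Random 3-channel binary label map, shaped like `labels` above
labels = torch.randint(low=0, high=2, size=(3, 32, 32), dtype=torch.long)

# interpolate expects a batch dimension and a floating-point input
x = labels.unsqueeze(0).float()

# Nearest neighbour: every output pixel is a copy of some input pixel
nearest = F.interpolate(x, size=(20, 20), mode="nearest").squeeze(0).long()

# Bilinear: output pixels are weighted averages, so values between 0 and 1 can appear
bilinear = F.interpolate(x, size=(20, 20), mode="bilinear", align_corners=False).squeeze(0)

print(nearest.shape, nearest.dtype, nearest.unique())
print(bilinear.unique().numel())
```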
Hope this helps