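For a segmentation use case, you should use the functional API of torchvision.transforms (torchvision.transforms.functional). It lets you sample the random parameters once and apply exactly the same transformation to both the input image and its target mask, as shown in the torchvision.transforms documentation.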