torchvision.transforms: How to perform identical transforms on both image and target?

A simple search on Google will give you this


Oh, my fault.
Thank you for the tips.

Hi jdhao,

Sorry to bother you again.
I’ve checked the related issues and PRs, but I’m still confused about how to apply an identical transform.

So is there any documentation about this part?

Thank you.


PyTorch has an excellent tutorial on data loading. In that tutorial, the author shows how to transform both the data and the target. You can try to mimic the approach from the tutorial.


The question hasn’t been answered. The provided references don’t show the recommended way to apply identical transforms to both the input image and the segmentation label…


Alternatively to the functions from the tutorial, you could use torchvision’s functional API.
Here is a small example for an image and the corresponding mask image:

class MyDataset(Dataset):
    def __init__(self, image_paths, target_paths, train=True):
        self.image_paths = image_paths
        self.target_paths = target_paths

    def transform(self, image, mask):
        # Resize
        resize = transforms.Resize(size=(520, 520))
        image = resize(image)
        mask = resize(mask)

        # Random crop
        i, j, h, w = transforms.RandomCrop.get_params(
            image, output_size=(512, 512))
        image = TF.crop(image, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)

        # Random horizontal flipping
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)

        # Random vertical flipping
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)

        # Transform to tensor
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        return image, mask

    def __getitem__(self, index):
        image =[index])
        mask =[index])
        x, y = self.transform(image, mask)
        return x, y

    def __len__(self):
        return len(self.image_paths)

what is TF in your code?


Ah sorry, it’s:

import torchvision.transforms.functional as TF

Assuming both the input and the ground truth are images: if we concatenate the input and GT along the channel axis and then pass the concatenated image through, say, torchvision.transforms.RandomHorizontalFlip(), that makes sure the GT is flipped whenever the corresponding input is flipped. I am not sure whether it will work in practice, since I have not tried it, but it makes sense to me in theory.

The current transformations work on PIL.Images, so your concatenated image might not be recognized as a valid image. Apart from that, it seems like a good idea.

Thanks for the clarification. I implemented the random horizontal flip by generating a random number; if it is greater than a threshold, I apply a normal horizontal flip to both the input and the GT.

Just to add to this thread - the linked PyTorch tutorial on data loading is somewhat confusing. The author uses both from skimage import io, transform and from torchvision import transforms, utils.

For skimage’s transform, the author uses the resize() function and wraps it in a custom Rescale class. For torchvision’s transforms, the author uses transforms.Compose to chain two transformations. But they come from two different modules!

To add to the confusion, torchvision’s transforms also has its own Resize() class, which shares its name with the resize function in the skimage module.

@ptrblck Where can I get the index from?

Do you mean the index in __getitem__?
You can pass the index directly to your Dataset:

data, target = dataset[0]

or the sampler will do it for you, if you are using a DataLoader.


You can also use the standard classes from the torchvision library:

from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as tvF

# custom transform class that applies the same crop to both images
class RRC(transforms.RandomResizedCrop):
    def __call__(self, img1, img2):
        assert img1.size == img2.size
        # draw the crop parameters once
        i, j, h, w = self.get_params(img1, self.scale, self.ratio)
        # return both images with the same transformation applied
        return [tvF.resized_crop(img1, i, j, h, w, self.size, self.interpolation),
                tvF.resized_crop(img2, i, j, h, w, self.size, self.interpolation)]

imgInput ='input.png')
imgTarget ='target.png')

inputSize = (512, 512)  # desired output size
imgInput, imgTarget = RRC(inputSize, scale=(0.5, 1.0), ratio=(5.0/6.0, 6.0/5.0))(imgInput, imgTarget)


@ptrblck: Can it work on 3D data like DxHxW, where D is the depth or number of slices? Your example is for 2D images (HxW). Thanks

torchvision.transforms often relies on PIL as the underlying library, so you would need to transform each slice separately. (At least I’m not aware of PIL methods that work on volumetric data.)

That being said, it might be faster to write the transformations, e.g. random cropping, manually and apply them directly on the tensor data.
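One way to sketch such a manual random crop directly on tensors (the helper name `random_crop_3d` and all shapes here are made up for illustration):

```python
import torch

# A volumetric sample and its mask: D x H x W.
volume = torch.randn(32, 128, 128)
mask = torch.randint(0, 2, (32, 128, 128))

def random_crop_3d(vol, msk, size):
    # Draw one random corner and slice both tensors with it,
    # so the crop location is identical for data and target.
    d, h, w = size
    d0 = torch.randint(0, vol.shape[0] - d + 1, (1,)).item()
    h0 = torch.randint(0, vol.shape[1] - h + 1, (1,)).item()
    w0 = torch.randint(0, vol.shape[2] - w + 1, (1,)).item()
    return (vol[d0:d0 + d, h0:h0 + h, w0:w0 + w],
            msk[d0:d0 + d, h0:h0 + h, w0:w0 + w])

v, m = random_crop_3d(volume, mask, (16, 64, 64))
print(v.shape, m.shape)
```

Working on tensors also avoids the per-slice PIL round trip, which is usually the slow part.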

This code works for multiple images:

# random resized crop for a list of images
class RRC(transforms.RandomResizedCrop):
    def __call__(self, imgs):
            imgs (list of PIL Image): Images to be cropped and resized.

            list of PIL Image: Randomly cropped and resized images.
        for im in range(1, len(imgs)):
            assert imgs[im].size == imgs[0].size

        # draw the crop parameters once from the first image
        i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)

        for imgCount in range(len(imgs)):
            imgs[imgCount] = tvF.resized_crop(imgs[imgCount], i, j, h, w,
                                              self.size, self.interpolation)

        return imgs

You might want to try batchgenerators from MIC-DKFZ for your 3D data. It also works multiprocessed and natively supports 2D and 3D.