How can I perform an identical transform on both image and target?
For example, in semantic segmentation and edge detection, where the input image and the ground-truth target are both 2D images, the same transform must be applied to both.
A simple search on Google will give you this
Oh, my fault.
Thank you for the tips.
Sorry to bother you again.
I’ve checked related issues and PRs, but I’m still confused about how to apply an identical transform to both.
So is there any documentation about this part?
PyTorch has an excellent tutorial on data loading. In that tutorial, the author shows how to apply a transform to both data and target. You can try to mimic the approach used there.
The question hasn’t been answered. The provided references don’t show the ideal practice to do identical transforms for both input image and segmentation label…
Alternatively to the functions from the tutorial, you could use torchvision’s functional API.
Here is a small example for an image and the corresponding mask image:
import random

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF


class MyDataset(Dataset):  # class name is a placeholder
    def __init__(self, image_paths, target_paths, train=True):
        self.image_paths = image_paths
        self.target_paths = target_paths

    def transform(self, image, mask):
        # Resize
        resize = transforms.Resize(size=(520, 520))
        image = resize(image)
        mask = resize(mask)

        # Random crop: sample the crop window once, then apply it to both
        i, j, h, w = transforms.RandomCrop.get_params(
            image, output_size=(512, 512))
        image = TF.crop(image, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)

        # Random horizontal flipping (one draw decides for both)
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)

        # Random vertical flipping (one draw decides for both)
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)

        # Transform to tensor
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        return image, mask

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        mask = Image.open(self.target_paths[index])
        x, y = self.transform(image, mask)
        return x, y

    def __len__(self):
        return len(self.image_paths)
Ah sorry, it’s:
import torchvision.transforms.functional as TF
Assuming both the input and the ground truth are images: if we concatenate the input and GT along the axis and then pass the concatenated image through torchvision.transforms.RandomHorizontalFlip(), say, the GT is guaranteed to be flipped whenever the corresponding input is flipped. I have not tried this, so I am not sure whether it works in practice, but it makes sense to me in theory.
The current transformations work with PIL.Images, so your concatenated image might not be recognized as a valid image. Besides that, it seems to be a good idea.
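For what it's worth, here is a minimal sketch of the concatenation idea done on tensors instead of PIL images, which sidesteps the limitation mentioned above. It assumes channel-first tensors of shape (C, H, W) and uses torch.flip rather than RandomHorizontalFlip; joint_hflip is a hypothetical helper name, not part of torchvision.

```python
import torch


def joint_hflip(image, mask, p=0.5):
    """Flip image and mask together by stacking them along the channel axis.

    Assumes both are channel-first tensors, (C, H, W), with equal H and W.
    """
    n_img = image.shape[0]
    # Stack along channels so a single flip decision covers both
    stacked = torch.cat([image, mask.to(image.dtype)], dim=0)
    if torch.rand(1).item() < p:
        stacked = torch.flip(stacked, dims=[-1])  # flip the width axis
    # Split back into image and mask, restoring the mask dtype
    return stacked[:n_img], stacked[n_img:].to(mask.dtype)


image = torch.arange(3 * 2 * 4, dtype=torch.float32).reshape(3, 2, 4)
mask = torch.tensor([[[0, 1, 2, 3], [4, 5, 6, 7]]])  # shape (1, 2, 4)
flipped_image, flipped_mask = joint_hflip(image, mask, p=1.0)
```

The cast to the image dtype is only there so that torch.cat accepts both tensors; for integer label masks with small values the round trip is exact.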
Thanks for the clarification. I implemented the random horizontal flip by generating a random number and, if it is greater than a threshold, doing a normal horizontal flip of both the input and the GT.
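The threshold approach can be wrapped in a small reusable callable. This is only a sketch: PairedRandomApply and hflip_rows are hypothetical names, and nested lists stand in for images to keep it dependency-free. With PIL images you would pass TF.hflip as fn instead.

```python
import random


class PairedRandomApply:
    """Apply `fn` to both inputs, or to neither, based on a single random draw."""

    def __init__(self, fn, p=0.5, rng=None):
        self.fn = fn
        self.p = p
        self.rng = rng or random.Random()

    def __call__(self, image, target):
        # One draw decides for both, so image and target can never disagree
        if self.rng.random() < self.p:
            return self.fn(image), self.fn(target)
        return image, target


def hflip_rows(rows):
    """Horizontal flip for an image stored as a list of pixel rows."""
    return [row[::-1] for row in rows]


image = [[1, 2, 3], [4, 5, 6]]
target = [[0, 0, 1], [1, 0, 0]]
flip = PairedRandomApply(hflip_rows, p=0.5, rng=random.Random(0))
flipped_image, flipped_target = flip(image, target)
```

The important part is that the random number is drawn once per pair. Calling two independent random transforms, one per image, would flip them independently.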
Just to add to this thread: the linked PyTorch tutorial on data loading is kind of confusing. The author uses both
from skimage import io, transform
and
from torchvision import transforms, utils
For transform, the author uses a resize() function and wraps it in a custom Rescale class. For transforms, the author uses the transforms.Compose function to chain two transformations. But they are from two different modules!
To add to the confusion, torchvision transforms also has its own API called Resize(), which has the same name as the one in skimage's transform.
@ptrblck Where can I get the index from?
Do you mean the
You can pass the index directly to your dataset, e.g.
data, target = dataset[index]
or the sampler will do it for you, if you are using a DataLoader.
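To make it concrete where the index comes from: it is simply the argument Python passes to __getitem__ when the dataset is subscripted. The sketch below uses a plain class (PairDataset, a hypothetical name) in place of a torch.utils.data.Dataset subclass so that it runs without PyTorch; the indexing behaviour is the same.

```python
class PairDataset:
    """Minimal map-style dataset: anything with __len__ and __getitem__."""

    def __init__(self, images, targets):
        assert len(images) == len(targets)
        self.images = images
        self.targets = targets

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # `index` is supplied by whoever subscripts the dataset
        return self.images[index], self.targets[index]


dataset = PairDataset(["img0", "img1", "img2"], ["mask0", "mask1", "mask2"])

# Direct indexing: Python calls dataset.__getitem__(1)
data, target = dataset[1]

# Roughly what a sequential sampler does: generate indices, fetch each sample
samples = [dataset[i] for i in range(len(dataset))]
```

When a map-style dataset is wrapped in a DataLoader, the sampler produces the indices and the loader fetches the samples, so you never pass the index yourself.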
You can also use the standard functions from the torchvision library:

from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as tvF


# Custom transform: random resized crop applied identically to two images
class RRC(transforms.RandomResizedCrop):
    def __call__(self, img1, img2):
        assert img1.size == img2.size
        # Sample the crop parameters once
        i, j, h, w = self.get_params(img1, self.scale, self.ratio)
        # Return both images with the same transformation applied
        return [tvF.resized_crop(img1, i, j, h, w, self.size, self.interpolation),
                tvF.resized_crop(img2, i, j, h, w, self.size, self.interpolation)]


imgInput = Image.open('input.png')
imgTarget = Image.open('target.png')
# inputSize is the desired output size and must be defined beforehand
imgInput, imgTarget = RRC(inputSize, scale=(0.5, 1.0), ratio=(5.0/6.0, 6.0/5.0))(imgInput, imgTarget)
@ptrblck: Can it work on 3D data such as DxHxW, where D is the depth or number of slices? Your example is a 2D image, HxW. Thanks
torchvision.transforms often rely on PIL as the underlying library, so you would need to transform each slice separately. (At least I’m not aware of PIL methods working on volumetric data.)
That being said, it might be faster to write the transformations, e.g. random cropping, manually and apply them directly on the tensor data.
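A rough sketch of a manual random crop for volumetric data: sample the crop window once and reuse it for both the volume and its label. sample_crop_slices is a hypothetical helper and the shapes are made up for illustration. It only computes the window, so it has no dependencies.

```python
import random


def sample_crop_slices(shape, crop_size, rng=random):
    """Sample one random crop window and return it as a tuple of slices.

    shape and crop_size hold one entry per axis, e.g. (D, H, W).
    """
    if len(shape) != len(crop_size):
        raise ValueError("shape and crop_size must have the same length")
    slices = []
    for dim, crop in zip(shape, crop_size):
        if crop > dim:
            raise ValueError(f"crop size {crop} exceeds dimension {dim}")
        start = rng.randint(0, dim - crop)  # randint is inclusive on both ends
        slices.append(slice(start, start + crop))
    return tuple(slices)


# Example: a window of size 16x128x128 inside a 32x256x256 volume
window = sample_crop_slices((32, 256, 256), (16, 128, 128))
# Reuse the same window for both arrays so they stay aligned:
#   volume_crop = volume[window]
#   label_crop = label[window]
```

Indexing with a tuple of slices works on both torch tensors and NumPy arrays, so the same window can be applied to the image volume and the label volume without going through PIL.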
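This code handles multiple layers:

# random resized crop for a list of images
# (class header reconstructed; the name RRC follows the usage shown earlier)
class RRC(transforms.RandomResizedCrop):
    def __call__(self, imgs):
        """
        Args:
            imgs (list of PIL Image): Images to be cropped and resized.

        Returns:
            list of PIL Image: Randomly cropped and resized images.
        """
        for im in range(1, len(imgs)):
            assert imgs[im].size == imgs[0].size
        # Sample the crop parameters once, from the first image
        i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)
        # Apply the same crop to every image (the list is modified in place)
        for imgCount in range(len(imgs)):
            imgs[imgCount] = tvF.resized_crop(imgs[imgCount], i, j, h, w, self.size, self.interpolation)
        return imgs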