Transformation is reducing the number of channels

import torch
from torchvision import transforms

M = torch.randint(low=0, high=2, size=(6, 64, 64), dtype=torch.float)
N = torch.randint(low=0, high=2, size=(3, 64, 64), dtype=torch.float)

gt_trans = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((64, 64)),
            transforms.ToTensor()])

print(M.shape)
print(gt_trans(M).shape)

print(N.shape)
print(gt_trans(N).shape)

Why is the number of channels of M reduced to 3 after the transformation? Thanks for your explanation(s).

@ptrblck I would appreciate your assistance.

Moreover, I am using two loss functions while training a network and ran into the error "RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.". I would like to know whether it is okay to pass retain_graph=True for only one of the losses (for proper loss computation) or whether it has to be passed for both, knowing that this is computationally expensive.
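
For context, here is a minimal sketch of the two-loss setup I mean (the model and losses are just placeholders):

import torch

model = torch.nn.Linear(10, 1)
out = model(torch.randn(4, 10))

loss1 = out.mean()
loss2 = out.pow(2).mean()

# the option I am asking about: retain the graph only on the first backward
loss1.backward(retain_graph=True)
loss2.backward()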

Apparently, the torchvision transform transforms.ToPILImage() never checks whether the image has more than 4 channels and silently passes this kind of data to PIL.Image.fromarray with mode='RGB'. That method, when the mode 'RGB' is given explicitly, again does not check the number of channels (though it does check when no mode is provided) and simply reduces the number of channels to 3.
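
As a side note, if you are stuck on an older torchvision and have to go through PIL, one workaround (just a sketch) is to resize each channel separately, so that ToPILImage never sees more than one channel at a time:

import torch
from torchvision import transforms

to_pil = transforms.ToPILImage()
resize = transforms.Resize((32, 32))
to_tensor = transforms.ToTensor()

M = torch.randint(low=0, high=2, size=(6, 64, 64), dtype=torch.float)

# resize channel by channel, then stack the results back into a 6-channel tensor
resized = torch.cat([to_tensor(resize(to_pil(c.unsqueeze(0)))) for c in M], dim=0)
print(resized.shape)  # torch.Size([6, 32, 32])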


@moreshud Since torchvision v0.8.x can work on tensors directly, you can do the following:

import torch
from torchvision import transforms

M = torch.randint(low=0, high=2, size=(6, 64, 64), dtype=torch.float)

gt_trans = transforms.Compose([
    transforms.Resize((32, 32)),
])
print(M.shape)
print(gt_trans(M).shape)
> torch.Size([6, 64, 64])
> torch.Size([6, 32, 32])

Why did you remove transforms.ToPILImage() and transforms.ToTensor() as both are needed?

Because torchvision v0.8.x can work on tensors directly.

But for 2D tensors, transforms.ToPILImage() and transforms.ToTensor() are needed. Also, a 2D tensor returns float32 while a 3D tensor returns float64 after the respective transformation.

Thanks for your contribution.

But for 2D tensors, transforms.ToPILImage() and transforms.ToTensor() are needed.

Yes, I see your point here. Images as tensors should be 3D tensors. In the case of 2D tensors, I'd recommend updating the way you construct the input so that it is a 3D tensor instead of a 2D one. In that case, if you already deal with tensors and your custom transformations do not specifically require a PIL.Image as input, you can safely remove ToPILImage and ToTensor.
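
For example, here is a minimal sketch of what I mean by making the input 3D (the mask is just a placeholder):

import torch
from torchvision import transforms

mask_2d = torch.randint(low=0, high=2, size=(64, 64), dtype=torch.float)
mask_3d = mask_2d.unsqueeze(0)  # (H, W) -> (1, H, W)

resize = transforms.Resize((32, 32))
print(resize(mask_3d).shape)  # torch.Size([1, 32, 32])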

Also, a 2D tensor returns float32 while a 3D tensor returns float64 after the respective transformation.

It probably depends on how you get the input as a tensor. If you wish to change the dtype of the tensor, this can be done with ConvertImageDtype: https://pytorch.org/docs/stable/torchvision/transforms.html#transforms-on-torch-tensor-only
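
For example, a quick sketch with ConvertImageDtype on a tensor input:

import torch
from torchvision import transforms

t = torch.randint(low=0, high=2, size=(3, 64, 64), dtype=torch.float64)
to_float32 = transforms.ConvertImageDtype(torch.float32)
print(to_float32(t).dtype)  # torch.float32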

Normally, transformations should not alter the dtype, for example:

import torch
from torchvision import transforms


for dtype in [torch.float32, torch.float64]:
    M = torch.randint(low=0, high=2, size=(6, 64, 64), dtype=dtype)
    gt_trans = transforms.Compose([
        transforms.Resize((32, 32)),
    ])
    res = gt_trans(M)
    assert res.dtype == M.dtype, "{} vs {}".format(res.dtype, M.dtype)

If that is not the case for your use case and it is related to torchvision, this should be a bug, so please file an issue :+1:. Thanks!


Thanks for the explanation.

Actually, I am performing volumetric segmentation using a 2D slice-wise approach. The image fed into the model has a spatial dimension of [512, 512] and has three one-hot encoded labels of the same dimension.

import torch
from torchvision import transforms

image = torch.randint(low=0, high=2, size=(3, 3), dtype=torch.float64)
labelA = torch.randint(low=0, high=2, size=(3, 3), dtype=torch.float64)
labelB = torch.randint(low=0, high=2, size=(3, 3), dtype=torch.float64)

# labelC is the background of labelA and labelB, i.e. it is 0 wherever either A or B is 1
labelC = torch.zeros(labelB.shape, dtype=torch.long)

for label in [labelA, labelB]:
    labelC |= label.long()
labelC = labelC ^ 1

labels = torch.cat([labelC.unsqueeze_(0)] + [x.unsqueeze_(0) for x in [labelA, labelB]], dim=0)
print(labelA, labelB.long(), labelC)
print(labels.shape)


# print(labels.shape)
image_trans = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((3, 3)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])])

gt_trans = transforms.Compose([
        transforms.Resize((3, 3)),
])
transformed_image =  image_trans(image)
transformed_labels = gt_trans(labels)
print(transformed_image.shape, transformed_labels.shape)
print(transformed_image.dtype, transformed_labels.dtype)

First, is there a better approach to creating labelC using PyTorch built-in functions rather than my naive approach of using bitwise operators?

Regarding the change in dtype after the transformation, I am not quite sure whether this comes from my implementation or is a bug per se.

@moreshud Sorry for the late reply.

First, is there a better approach to creating labelC using PyTorch built-in functions rather than my naive approach of using bitwise operators?

You can do it in a one-line expression like this:

labelA = torch.randint(low=0, high=2, size=(32, 32), dtype=torch.long)  # !!! dtype is long here
labelB = torch.randint(low=0, high=2, size=(32, 32), dtype=torch.long)

labelC = torch.clamp(1 - labelA - labelB, 0, 1)
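
Or, equivalently, with boolean ops on the same labelA and labelB (just another option):

labelC = ((labelA == 0) & (labelB == 0)).long()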

Regarding the change in dtype after the transformation, I am not quite sure whether this comes from my implementation or is a bug per se.

Here is a working example with torchvision 0.8.1

import torch
from torchvision import transforms

image = torch.rand(1, 32, 32)
labelA = torch.randint(low=0, high=2, size=(32, 32), dtype=torch.long)
labelB = torch.randint(low=0, high=2, size=(32, 32), dtype=torch.long)
labelC = torch.clamp(1 - labelA - labelB, 0, 1)

labels = torch.stack([labelC, labelA, labelB])
print(labels.shape)


image_trans = transforms.Compose([
            transforms.Resize((20, 20), interpolation=2),  # 2 = PIL.Image.BILINEAR
            transforms.Normalize([0.5], [0.5])])

gt_trans = transforms.Compose([
        transforms.Resize((20, 20), interpolation=0),  # 0 = PIL.Image.NEAREST, keeps the labels discrete
])
transformed_image =  image_trans(image)
transformed_labels = gt_trans(labels)
print(transformed_image.shape, transformed_labels.shape)
print(transformed_image.dtype, transformed_labels.dtype)

> torch.Size([3, 32, 32])
> torch.Size([1, 20, 20]) torch.Size([3, 20, 20])
> torch.float32 torch.int64

Hope this helps
