torch.nn.functional.interpolate usage

I am trying to use torch.nn.functional.interpolate to resize an RGB image on the GPU. I allocate an RGB/BGR input as follows:

import torch
x = torch.rand(10, 10, 3).cuda()

Now I want to downsample the image by a factor of 2, but only in the spatial dimensions, so the final size of the image should be (5, 5, 3). I am trying something like:

a = torch.nn.functional.interpolate(x, (5, 5, 3), mode='linear')

However, this fails because it expects the size parameter to be a single element. I tried a few other things, but I am not sure how to set it up so that the interpolation happens only over the non-channel dimensions.

I think the error comes from your data layout. PyTorch's image functions, including interpolate, expect channels-first data in BCHW (Batch x Channel x Height x Width) order, not the height x width x channels layout you used. So a tensor representing a single image should instead be declared as x = torch.rand(3, 10, 10).cuda().
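If the image already exists in height x width x channels order (for example, loaded from an image library), it does not have to be reallocated; a minimal sketch, assuming you just want to reorder the existing tensor, is to permute its dimensions. The CPU fallback is an assumption added so the snippet also runs on a machine without a GPU.

```python
import torch

# Use the GPU when available, otherwise fall back to the CPU (assumption for portability).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Image stored as height x width x channels, like the tensor in the question.
x_hwc = torch.rand(10, 10, 3, device=device)

# Reorder the dimensions to channels x height x width (no data is copied; this is a view).
x_chw = x_hwc.permute(2, 0, 1)
print(x_chw.shape)  # torch.Size([3, 10, 10])
```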

Also, interpolate requires the input to be in actual BCHW format, not CHW as above: it treats the first two dimensions as batch and channels and resizes only the remaining ones, so a 3D tensor would be interpolated along its last dimension only. Thus, replace the previous line with x = torch.rand(1, 3, 10, 10).cuda(), or call unsqueeze(0) on x to add a batch dimension (with only one image the batch size is 1, but the function still requires that dimension).
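Both options for adding the batch dimension can be sketched side by side; the device selection is an assumption so the snippet also runs without CUDA.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.rand(3, 10, 10, device=device)  # CHW, no batch dimension yet

# unsqueeze(0) inserts a batch dimension of size 1 at the front.
x_batched = x.unsqueeze(0)
print(x_batched.shape)  # torch.Size([1, 3, 10, 10])

# Alternative: allocate the tensor with the batch dimension from the start.
x_direct = torch.rand(1, 3, 10, 10, device=device)
```

unsqueeze(0) returns a view, so the original image data is shared rather than copied.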

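Lastly, as stated in the documentation, 'linear' mode only accepts 3D input (batch, channels, length), so it will not work with your data: the image is already 3D without even counting the batch dimension, and 4D once the batch dimension is added. You need 'bilinear' mode, which interpolates over the two spatial dimensions of a 4D input, and the size argument should then list only those spatial dimensions, i.e. (5, 5). I'm not sure, though, that bilinear is the interpolation you want.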
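Putting the three fixes together, a minimal end-to-end sketch might look like the following. It assumes the image starts in the question's height x width x channels layout and should end up in that layout again; the CPU fallback and the scale_factor variant are illustrative additions, not part of the original answer.

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

# RGB image in height x width x channels layout, as in the question.
img_hwc = torch.rand(10, 10, 3, device=device)

# HWC -> CHW -> BCHW (batch of one) so interpolate sees two spatial dimensions.
img_bchw = img_hwc.permute(2, 0, 1).unsqueeze(0)

# Bilinear resize of height and width only; size lists the spatial dimensions only.
# align_corners=False is the default, passed explicitly here for clarity.
small = F.interpolate(img_bchw, size=(5, 5), mode="bilinear", align_corners=False)

# The same downsampling expressed as a scale factor instead of an output size.
small_scaled = F.interpolate(img_bchw, scale_factor=0.5, mode="bilinear", align_corners=False)

# Drop the batch dimension and go back to height x width x channels.
small_hwc = small.squeeze(0).permute(1, 2, 0)
print(small_hwc.shape)  # torch.Size([5, 5, 3])
```

Using size fixes the output shape exactly, while scale_factor is convenient when the same ratio should apply to images of different sizes.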

Thanks. That did the trick!
