What information does nn.functional.interpolate use?

I have a PyTorch tensor img of size b x 2 x h x w and want to upsample it using torch.nn.functional.interpolate. During interpolation, channel 1 should not use any information from channel 2. To achieve this, should I interpolate each channel separately,

img2 = torch.rand(b,2,2*h,2*w) # output buffer; its random values are overwritten below
# Slice with 0:1 and 1:2 to keep the channel dimension: bilinear mode expects a 4-D input.
img2[:,0:1,:,:] = nn.functional.interpolate(img[:,0:1,:,:], [2*h,2*w], mode='bilinear', align_corners=True)
img2[:,1:2,:,:] = nn.functional.interpolate(img[:,1:2,:,:], [2*h,2*w], mode='bilinear', align_corners=True)

or will a single call on the whole tensor,

img = nn.functional.interpolate(img, [2*h,2*w], mode='bilinear', align_corners=True)

serve the same purpose?

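If you pass a 4-D tensor of shape (N, C, H, W), interpolate resamples only the spatial dimensions H and W, so the batch and channel dimensions are not mixed and a single call on the whole tensor is enough. A quick check against a per-channel loop: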

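import torch

# Random 4-D input: batch of 16, 16 channels, 32x32 spatial size.
x = torch.randn(16, 16, 32, 32)
# Upsample the whole tensor in one call (the default mode is 'nearest').
y = torch.nn.functional.interpolate(x, size=(64, 64))
# Upsample every (sample, channel) plane on its own for comparison.
y2 = torch.empty(16, 16, 64, 64)
for n in range(16):
    for c in range(16):
        inp = x[n,c,:,:].reshape(1, 1, 32, 32)
        y2[n,c,:,:] = torch.nn.functional.interpolate(inp, size=(64, 64))
# Element-wise comparison of the two results.
print(torch.all(y == y2))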
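The check above uses the default 'nearest' mode, while the question uses mode='bilinear' with align_corners=True. Here is a minimal sketch of the same comparison with those settings on a b x 2 x h x w tensor. The sizes b, h, w and the variable names are made up for illustration. It also perturbs channel 1 and checks that the upsampled channel 0 does not change.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, w = 4, 8, 8  # hypothetical sizes, chosen only for this sketch
img = torch.rand(b, 2, h, w)
size = (2 * h, 2 * w)

# One call on the whole 4-D tensor.
full = F.interpolate(img, size=size, mode="bilinear", align_corners=True)

# Channel by channel; the c:c+1 slice keeps the channel dimension so the input stays 4-D.
per_channel = torch.cat(
    [
        F.interpolate(img[:, c:c + 1], size=size, mode="bilinear", align_corners=True)
        for c in range(img.shape[1])
    ],
    dim=1,
)
same = torch.allclose(full, per_channel)

# Change only channel 1 of the input and upsample again.
img_perturbed = img.clone()
img_perturbed[:, 1] += 10.0
full_perturbed = F.interpolate(img_perturbed, size=size, mode="bilinear", align_corners=True)
# If channels are interpolated independently, channel 0 of the output is unaffected.
channel0_unchanged = torch.allclose(full[:, 0], full_perturbed[:, 0])

print(same, channel0_unchanged)
```

If the channels are indeed handled independently, both flags come out True, and the single call from the question can replace the per-channel version.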