Linear Interpolate along Channel/Z axis?

I have a tensor of size [B, 64, 256, 384]. I want to stretch it smoothly (linearly interpolate) along the channel axis so it becomes [B, 256, 256, 384]. I tried the following, but I get an error:

```python
dpv_refined_predicted = F.interpolate(dpv_refined_predicted, size=[256,256,384], mode='linear')
dpv_refined_predicted = F.upsample(dpv_refined_predicted, size=[256,256,384], mode='linear')
```
```
NotImplementedError: Got 4D input, but linear mode needs 3D input
```

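F.interpolate only resizes the spatial dimensions, i.e. every dimension after the first two (batch and channel). A 4D input is treated as [N, C, H, W] and needs a 2D mode such as mode='bilinear', while mode='linear' expects a 3D [N, C, L] input, which is why the call fails. The channel dimension itself is never interpolated. (F.upsample is a deprecated alias of F.interpolate, so it fails the same way.)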
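Because every dimension after the first two is treated as spatial, one possible workaround that keeps the size=[256, 256, 384] call from the question is to add a singleton channel dimension and use mode='trilinear' on the resulting 5D tensor, so the old channel axis becomes the "depth" axis. This is a sketch assuming the default align_corners=False; H and W keep their sizes, so only the depth axis is actually resampled.

```python
import torch
import torch.nn.functional as F

B = 2
x = torch.randn(B, 64, 256, 384)  # [B, C, H, W]

# [B, 64, 256, 384] -> [B, 1, 64, 256, 384]: the old channel axis is now the depth axis
x5d = x.unsqueeze(1)

# Trilinear resizes the three trailing dims; 256 -> 256 and 384 -> 384 leave H and W
# unchanged, so only depth (the former channels) is stretched from 64 to 256.
out = F.interpolate(x5d, size=[256, 256, 384], mode='trilinear', align_corners=False)

out = out.squeeze(1)  # drop the singleton dim: [B, 256, 256, 384]
print(out.shape)  # torch.Size([2, 256, 256, 384])
```

With align_corners=False, output channel j is a blend of input channels floor(s) and floor(s) + 1 (clamped to the last channel), where s = max(0, (j + 0.5) * 64/256 - 0.5), so the first and last output channels equal the first and last input channels.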
To interpolate along the channel dimension with a 4D call, you can permute the input so the channels sit in a spatial position, interpolate, and permute the output back:

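```python
import torch
import torch.nn.functional as F

B = 2
x = torch.randn([B, 64, 256, 384])
x = x.permute(0, 2, 1, 3)  # [B, 256, 64, 384]: channels now occupy the "height" slot
out = F.interpolate(x, [256, 384], mode='bilinear', align_corners=False)  # 64 -> 256; width 384 unchanged
out = out.permute(0, 2, 1, 3)  # swap back: [B, 256 (channels), 256, 384]
```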
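Another option, sketched under the same align_corners=False assumption, is to fold every pixel into the batch dimension so the channel axis becomes the length axis of a 3D [N, C, L] tensor, which is exactly the layout mode='linear' accepts. The helper name interp_channels is hypothetical. It should give the same result as the permute + bilinear approach, because bilinear with an unchanged width (384 to 384) leaves that axis untouched.

```python
import torch
import torch.nn.functional as F


def interp_channels(x: torch.Tensor, out_channels: int) -> torch.Tensor:
    """Linearly resample dim 1 (channels) of a [B, C, H, W] tensor to out_channels."""
    B, C, H, W = x.shape
    # [B, C, H, W] -> [B, H, W, C] -> [B*H*W, 1, C]: the (N, C, L) layout mode='linear' expects
    x3d = x.permute(0, 2, 3, 1).reshape(B * H * W, 1, C)
    y = F.interpolate(x3d, size=out_channels, mode='linear', align_corners=False)
    # [B*H*W, 1, out_channels] -> [B, H, W, out_channels] -> [B, out_channels, H, W]
    # (the result is a permuted view; call .contiguous() if a later op needs it)
    return y.reshape(B, H, W, out_channels).permute(0, 3, 1, 2)


B = 2
x = torch.randn(B, 64, 256, 384)
out = interp_channels(x, 256)
print(out.shape)  # torch.Size([2, 256, 256, 384])
```

The reshape after the first permute copies the input once; in exchange, the interpolation runs only along the channel axis instead of also passing over an axis that is left unchanged.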