Upsampling Only Spatial Dimensions in a 5D Tensor

I have a 5D tensor x (frames of a video) and I want to upsample only its spatial size (the last two dimensions). However, when I use upsampling, the last three dimensions of the tensor are upsampled. For upsampling I use the following class:

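import torch.nn as nn
from torch.nn.functional import interpolate

class Upsample(nn.Module):
    def __init__(self, scale_factor, mode, align_corners=False):
        super().__init__()  # required before an nn.Module can be called
        self.interp = interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        # Pass align_corners through as well; it was stored but never used
        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode,
                        align_corners=self.align_corners)
        return x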

For example, the main class in which I upsample the 5D tensor looks like this (condensed):

class Main(nn.Module):
    def __init__(self):
        super(Main, self).__init__()
        self.upsample = Upsample(scale_factor=2, mode='trilinear')

    def forward(self, x):
        x = self.upsample(x)
        return x

For example, applying this upsampling to a tensor x of shape (2, 4, 3, 10, 20) produces a tensor of shape (2, 4, 6, 20, 40), but I need the shape (2, 4, 3, 20, 40).
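A minimal reproduction of the shape change described above, calling F.interpolate directly (which is what the Upsample wrapper does) with a single scalar scale_factor:

```python
import torch
import torch.nn.functional as F

# Video-like 5D input: (batch, channels, depth/frames, height, width)
x = torch.randn(2, 4, 3, 10, 20)

# A scalar scale_factor is applied to all three volumetric dims (D, H, W)
out = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False)
print(out.shape)  # torch.Size([2, 4, 6, 20, 40])
```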

What is the problem and how can I solve this?

The issue is caused by the standard layout of a 5D tensor, [batch_size, channels, depth, height, width], combined with a single scalar scale_factor: the interpolation then applies that scale_factor to all three volumetric dimensions (depth, height, and width).
If you want to skip the interpolation in the depth dimension, pass scale_factor as a tuple with a factor of 1 for depth:

import torch
import torch.nn as nn

x = torch.randn(2, 4, 3, 10, 20)

# One scale factor per volumetric dim: (depth, height, width)
up = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear')
out = up(x)
print(out.shape)
# torch.Size([2, 4, 3, 20, 40])
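An alternative sketch, assuming each frame should simply be resized on its own: fold the depth (frame) dimension into the batch dimension, run 4D bilinear interpolation, and unfold again. The helper name upsample_frames is hypothetical. With a depth factor of 1 and align_corners=False, trilinear interpolation reduces to bilinear interpolation per frame, so the two results should agree up to floating-point rounding.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def upsample_frames(x, scale_factor=2, mode='bilinear'):
    # Hypothetical helper: resize only H and W of an (N, C, D, H, W) tensor
    n, c, d, h, w = x.shape
    # (N, C, D, H, W) -> (N, D, C, H, W) -> (N*D, C, H, W)
    frames = x.permute(0, 2, 1, 3, 4).reshape(n * d, c, h, w)
    frames = F.interpolate(frames, scale_factor=scale_factor, mode=mode,
                           align_corners=False)
    new_h, new_w = frames.shape[-2:]
    # (N*D, C, H', W') -> (N, D, C, H', W') -> (N, C, D, H', W')
    return frames.reshape(n, d, c, new_h, new_w).permute(0, 2, 1, 3, 4)


x = torch.randn(2, 4, 3, 10, 20)
out_2d = upsample_frames(x)
out_3d = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear')(x)
print(out_2d.shape)  # torch.Size([2, 4, 3, 20, 40])
print(torch.allclose(out_2d, out_3d, atol=1e-5))
```

The tuple scale_factor is the more direct fix; the reshape variant is mainly useful when a 2D-only operator has to be applied frame by frame.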

@ptrblck, thanks a lot for your answer. Is it also possible to include the batch dimension in the upsampling, in addition to the spatial dimensions?

No, you cannot directly interpolate the batch dimension using the Upsample module or F.interpolate. However, you could permute the input tensor such that the batch dimension would align with a spatial or volumetric dimension, interpolate it, and permute it back.
Note that this is not a typical use case since you are trying to interpolate “between samples”, so check if this is really what you want to do.
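A minimal sketch of the permute approach, assuming the goal is to scale the batch dimension by 2 together with height and width while leaving channels and depth untouched. The helper name upsample_batch_and_spatial is hypothetical. Swapping the batch (dim 0) and depth (dim 2) axes turns the original batch into one of the three volumetric dimensions that trilinear interpolation resizes, while depth temporarily acts as the batch:

```python
import torch
import torch.nn.functional as F


def upsample_batch_and_spatial(x, scale_factor=(2, 2, 2)):
    # Hypothetical helper for an (N, C, D, H, W) tensor.
    # Swap N and D: (D, C, N, H, W), so N becomes a volumetric dim
    swapped = x.permute(2, 1, 0, 3, 4).contiguous()
    # Trilinear interpolation over (N, H, W); D now plays the batch role
    out = F.interpolate(swapped, scale_factor=scale_factor,
                        mode='trilinear', align_corners=False)
    # The same permutation swaps the axes back: (N', C, D, H', W')
    return out.permute(2, 1, 0, 3, 4).contiguous()


x = torch.randn(2, 4, 3, 10, 20)
out = upsample_batch_and_spatial(x)
print(out.shape)  # torch.Size([4, 4, 3, 20, 40])
```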

@ptrblck, thanks a lot.