How to interpolate a batch of tensors over the channel dimension


I’m implementing a CNN-VAE with skip-connection layers in the encoder and decoder, to implicitly optimize the information flow between the input data and the latent representation.

I am aware that a ResBlock uses an identity shortcut mapping if the resolution (HxW) and the channel depth are kept unchanged, and otherwise uses a convolution in the shortcut to perform the appropriate upsampling/downsampling (for example, a 1x1 convolution if only the channel depth changes).
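The shortcut rule described above can be sketched like this (a minimal illustration, with `make_shortcut` as a hypothetical helper name, not part of any library):

```python
import torch
import torch.nn as nn

def make_shortcut(in_channels: int, out_channels: int, stride: int = 1) -> nn.Module:
    """Return an identity shortcut when shape is unchanged, else a 1x1 conv
    that projects to the new channel depth (and resolution, via stride)."""
    if in_channels == out_channels and stride == 1:
        return nn.Identity()
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)

x = torch.randn(2, 64, 32, 32)           # (N, C, H, W)
shortcut = make_shortcut(64, 128)        # channel depth changes: 1x1 conv
y = shortcut(x)                          # (2, 128, 32, 32)
```

Note that the 1x1 convolution is trainable, which is exactly what the non-trainable shortcuts discussed below avoid.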

My question is the following. According to the paper by Huangjie Zheng et al., a skip-connection shortcut is added in each layer using non-trainable upsampling/downsampling shortcut functions, which therefore does not increase network complexity.

That means an interpolation across the channel dimension is necessary, which is not a feature enabled by nn.functional.interpolate.
What would be a PyTorch implementation to do so? My workaround is to move to 3D volumetric data with torch.unsqueeze, perform a trilinear interpolation, and go back to 2D data. Is there a proper way?
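For reference, the workaround I describe looks like this (a sketch; the target channel count 128 is just an example):

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 64, 16, 16)   # (N, C, H, W), resampling C: 64 -> 128

# Treat the channel axis as the depth axis of a fake 3D volume:
v = x.unsqueeze(1)               # (N, 1, C, H, W)
v = F.interpolate(v, size=(128, 16, 16), mode="trilinear", align_corners=False)
y = v.squeeze(1)                 # (N, 128, 16, 16)
```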

Thanks for your help

You could permute the activation, treat the channels as a spatial dimension, apply F.interpolate on it, and permute the activation back to its original layout.
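A minimal sketch of that approach, again resampling 64 channels to 128 as an example (F.interpolate on a 3D tensor interpolates over the last dimension, so we move the channels there):

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 64, 16, 16)                    # (N, C, H, W)
n, c, h, w = x.shape

# Move channels to the last (spatial) position and flatten H, W:
t = x.permute(0, 2, 3, 1).reshape(n, h * w, c)    # (N, H*W, C)

# Linear interpolation along the channel axis:
t = F.interpolate(t, size=128, mode="linear", align_corners=False)

# Restore the original (N, C, H, W) layout:
y = t.reshape(n, h, w, 128).permute(0, 3, 1, 2)   # (N, 128, H, W)
```

This stays in 2D and avoids the unsqueeze/trilinear round trip; only the channel axis is resampled, so H and W are untouched.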