Hybrid 2d Conv+ConvTranspose layer

Let’s say I have a 4-dimensional tensor (batch x channel x time x space). I’d like to downsample the space axis and upsample the time axis. I can achieve this sequentially as follows:

import torch
import torch.nn as nn

tnsr = torch.randn(2, 3, 4, 5, requires_grad=True)  # batch x channel x time x space
space_downsampler = nn.Conv2d(3, 3, (1, 3), stride=(1, 2), padding=(0, 1))
tnsr = space_downsampler(tnsr)  # downsample space dimension
assert tnsr.shape == (2, 3, 4, 3)
time_upsampler = nn.ConvTranspose2d(3, 3, (3, 1), stride=(2, 1), padding=(1, 0))
tnsr = time_upsampler(tnsr)  # upsample time dimension
assert tnsr.shape == (2, 3, 7, 3)
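
For reference, the asserted shapes follow from PyTorch's documented output-size formulas (with dilation 1 and no output padding):

# Conv2d:          out = floor((in + 2*pad - k) / stride) + 1
#   space: floor((5 + 2*1 - 3) / 2) + 1 = 3
# ConvTranspose2d: out = (in - 1)*stride - 2*pad + k
#   time:  (4 - 1)*2 - 2*1 + 3 = 7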

This solution uses a 1x3 filter with stride (1,2) followed by a 3x1 filter with fractional stride (0.5, 1). But I would prefer to use a single 3x3 filter that has stride (0.5, 2). This would mean that the 3x3 filter is applied with fractional striding along one axis and with normal 2x striding along the other axis.

How might I implement this? One idea that comes to mind would be to preprocess the tensor by “injecting zeros” into the time axis, so that a poor-man’s fractional stride can be achieved using a normal convolution module:

tnsr = torch.randn(2, 3, 4, 5, requires_grad=True)  # batch x channel x time x space
tnsr_expanded = torch.zeros(2, 3, 7, 5)
tnsr_expanded[:, :, ::2] = tnsr  # interleave zeros along time: t0, 0, t1, 0, t2, 0, t3
hybrid_conv = nn.Conv2d(3, 3, (3, 3), stride=(1, 2), padding=(1, 1))
tnsr = hybrid_conv(tnsr_expanded)
assert tnsr.shape == (2, 3, 7, 3)

But there is a problem with the above implementation: as far as I know, one cannot backpropagate gradients through the tnsr_expanded[:, :, ::2] = tnsr assignment. Here is my clumsy alternative, which does support backpropagation:

tnsr = torch.randn(2, 3, 4, 5, requires_grad=True)  # batch x channel x time x space
zs = torch.zeros(2, 3, 5)  # a single all-zero time slice
tnsr_expanded = torch.stack((tnsr[:, :, 0], zs, tnsr[:, :, 1], zs, tnsr[:, :, 2], zs, tnsr[:, :, 3]), dim=2)
hybrid_conv = nn.Conv2d(3, 3, (3, 3), stride=(1, 2), padding=(1, 1))
tnsr = hybrid_conv(tnsr_expanded)
assert tnsr.shape == (2, 3, 7, 3)
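
Incidentally, the stack call above hardcodes a time length of 4. If it helps, here is a sketch (untested beyond this shape) of the same zero-interleaving written generically with stack and flatten:

tnsr = torch.randn(2, 3, 4, 5, requires_grad=True)
zs = torch.zeros_like(tnsr)                     # zeros matching tnsr's shape/dtype/device
pairs = torch.stack((tnsr, zs), dim=3)          # (2, 3, 4, 2, 5): each time step paired with zeros
tnsr_expanded = pairs.flatten(2, 3)[:, :, :-1]  # (2, 3, 7, 5): interleaved, trailing zero dropped
assert tnsr_expanded.shape == (2, 3, 7, 5)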

Any ideas for a more elegant solution?

EDIT:
Never mind, it turns out I was mistaken about the tnsr_expanded[:, :, ::2] = tnsr assignment: backpropagation through it works as expected.
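
For example, this minimal check confirms that gradients reach tnsr through the slice assignment:

tnsr = torch.randn(2, 3, 4, 5, requires_grad=True)
tnsr_expanded = torch.zeros(2, 3, 7, 5)
tnsr_expanded[:, :, ::2] = tnsr  # the slice assignment in question
tnsr_expanded.sum().backward()   # dummy scalar loss
assert (tnsr.grad == 1).all()    # every element of tnsr receives a gradient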