torch.nn.Upsample layer slow in forward pass

I’m running a version of StarGAN in which the ConvTranspose2d layers that up-sample the signal have been replaced by a resize convolution, as described below. For some reason, the forward pass through the Upsample layers is very slow compared to the rest of the network, which surprises me, since the operation doesn’t depend on any parameter matrices.

layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))

is replaced by

layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
layers.append(nn.ReflectionPad2d(1))
layers.append(nn.Conv2d(curr_dim, curr_dim // 2, kernel_size=3, stride=1, padding=0, bias=False))

The first pass through the Upsample layer has an input of size (256, 256, 256) (C × H × W) and takes roughly 9.5s.
The second pass has an input of size (128, 512, 512) and takes roughly 17.5s.
When the same up-sampling is performed by a script I wrote using NumPy, it takes 0.3s and 0.5s respectively.
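Roughly, the NumPy version boils down to two `repeat` calls along the spatial axes (a simplified sketch, not the exact script):

```python
import numpy as np

def upsample_nearest(x, scale=2):
    # Nearest-neighbour upsampling of a (C, H, W) array:
    # repeat every row and every column `scale` times.
    return x.repeat(scale, axis=-2).repeat(scale, axis=-1)

x = np.arange(4, dtype=np.float32).reshape(1, 2, 2)
y = upsample_nearest(x)  # shape (1, 4, 4)
```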

A colleague suggested that this could be caused by the fact that a GPU is designed to process images with fewer channels than the signals processed in this network, but I find this unsatisfying as the rest of the layers in the network run in a fraction of a second. I’m wondering if anyone else has any thoughts as to the cause?

Are you sure you’re measuring that right? I’m asking because most people don’t.
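CUDA ops run asynchronously, so a plain Python clock tends to attribute a kernel’s cost to whichever later call happens to wait for it. A synchronized timing helper might look like this (a sketch; the helper name is mine, and on CPU it reduces to plain timing):

```python
import time
import torch

def timed(fn, *args):
    # Flush pending CUDA work before starting and before stopping the
    # clock; otherwise the timer stops before the kernels have finished.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    out = fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return out, time.perf_counter() - t0

x = torch.randn(1, 8, 16, 16)
up = torch.nn.Upsample(scale_factor=2, mode='nearest')
y, dt = timed(up, x)
```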
For integral scaling, it can be faster (and with deterministic backward and all) to use array operations. You might split this into several lines, but to give you an idea, this would upscale a BCHW tensor inp by 2x:

out = inp[:, :, :, None, :, None].expand(-1, -1, -1, 2, -1, 2).reshape(inp.size(0), inp.size(1), 2 * inp.size(2), 2 * inp.size(3))
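As a sanity check, the trick matches nearest-neighbour interpolation exactly on a small tensor:

```python
import torch

inp = torch.arange(2 * 3 * 2 * 2, dtype=torch.float32).reshape(2, 3, 2, 2)

# Insert singleton axes after H and W, expand them to the scale factor
# (a view, no copy), then fold them back into the spatial dimensions.
out = inp[:, :, :, None, :, None].expand(-1, -1, -1, 2, -1, 2).reshape(
    inp.size(0), inp.size(1), 2 * inp.size(2), 2 * inp.size(3))

ref = torch.nn.functional.interpolate(inp, scale_factor=2, mode='nearest')
```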

Best regards

Thomas

I’m not totally sure that I’m measuring the times correctly, but the bottom line was that the network took too long to run with the resize convolution compared to the original network with the transpose convolution.

Nevertheless, you’re absolutely right: the array-operation method is vastly more efficient. The time for a forward pass through the network drops from ~65s to ~40s when up-sampling your way.

True, upsampling in PyTorch is very slow.