I have `x` of shape `[32, 16, 16, 16]`. I am trying to use `nn.Upsample()` to get `x` of shape `[32, 32, 32, 32]`, but I keep getting `[32, 16, 32, 32]`. Can anybody please help with this?
`nn.Upsample` interpolates only the spatial dimensions (every dim after dim 1). It never touches the batch (dim 0) or channel (dim 1) dims.
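A small sketch of that rule. The number of dims that get scaled depends only on the input's rank, while dims 0 and 1 always pass through unchanged. The tensor sizes here are arbitrary and chosen just for illustration:

```python
import torch
import torch.nn as nn

up = nn.Upsample(scale_factor=2)  # default mode="nearest"

# dim 0 = batch, dim 1 = channels, all remaining dims are spatial
x3 = torch.randn(2, 4, 8)        # (N, C, L)       -> 1 spatial dim
x4 = torch.randn(2, 4, 8, 8)     # (N, C, H, W)    -> 2 spatial dims
x5 = torch.randn(2, 4, 8, 8, 8)  # (N, C, D, H, W) -> 3 spatial dims

print(up(x3).shape)  # torch.Size([2, 4, 16])
print(up(x4).shape)  # torch.Size([2, 4, 16, 16])
print(up(x5).shape)  # torch.Size([2, 4, 16, 16, 16])
```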
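Since you explicitly also want to upsample the channel dimension (dim 1), you could `unsqueeze` an additional dim before upsampling and `squeeze` it afterwards. With the extra dim the input becomes 5D, so `nn.Upsample` treats your original channel dim as a depth dimension and scales it along with height and width: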
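```python
import torch
import torch.nn as nn

x = torch.randn(32, 16, 16, 16)
up = nn.Upsample(scale_factor=2)

# 4D input: only the two spatial dims (H, W) are scaled
out = up(x)
print(out.shape)
# torch.Size([32, 16, 32, 32])

# 5D input: the dummy dim makes dims 1-3 act as (D, H, W)
out = up(x.unsqueeze(1)).squeeze(1)
print(out.shape)
# torch.Size([32, 32, 32, 32])
```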
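It can help to know what the values look like after this trick. With the default nearest mode and `scale_factor=2`, every channel (and every spatial element) is simply copied twice. The sketch below checks that against `repeat_interleave`. The `trilinear` part is an alternative I am adding, not something from the answer above. It assumes you would rather blend neighbouring channels than duplicate them.

```python
import torch
import torch.nn as nn

x = torch.randn(32, 16, 16, 16)

# Nearest (default): the unsqueezed view is (N, 1, D, H, W), so the
# original channel dim is upsampled like a depth axis.
up = nn.Upsample(scale_factor=2, mode="nearest")
out = up(x.unsqueeze(1)).squeeze(1)

# Equivalent result: repeat every element twice along dims 1, 2 and 3.
ref = (x.repeat_interleave(2, dim=1)
        .repeat_interleave(2, dim=2)
        .repeat_interleave(2, dim=3))
print(torch.equal(out, ref))  # True

# Alternative: trilinear interpolation (needs a 5D input) blends
# neighbouring channels instead of copying them.
up_tri = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False)
out_tri = up_tri(x.unsqueeze(1)).squeeze(1)
print(out_tri.shape)  # torch.Size([32, 32, 32, 32])
```

Which mode fits depends on whether your channels have a meaningful ordering. Interpolating between unrelated feature channels rarely makes sense, so nearest (plain duplication) is usually the safer choice there.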