I have x of shape [32,16,16,16]. I am trying to use nn.Upsample() to make x of shape [32,32,32,32], but I keep getting [32,16,32,32]. Can anybody please help with this?
nn.Upsample interpolates the spatial dimensions, not the channel or batch dims.
Since you explicitly also want to upsample the channel dimension (dim1), you could unsqueeze an additional dim before upsampling the input and squeeze it afterwards:
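import torch
import torch.nn as nn

x = torch.randn(32, 16, 16, 16)
up = nn.Upsample(scale_factor=2)  # default mode is "nearest"

# A 4D input is treated as [N, C, H, W], so only H and W are upsampled
out = up(x)
print(out.shape)  # torch.Size([32, 16, 32, 32])

# With a dummy dim the input is [N, 1, C, H, W], so C is upsampled as well
out = up(x.unsqueeze(1)).squeeze(1)
print(out.shape)  # torch.Size([32, 32, 32, 32])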
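To make it clearer what the unsqueeze/squeeze approach does to the channel dimension, here is a small sketch using F.interpolate. The helper name upsample_with_channels is made up for illustration and is not part of PyTorch. With the default "nearest" mode, each channel is simply repeated, so the result should match a regular 2D upsample followed by repeat_interleave on dim1. If you want linear interpolation instead, note that a 5D input needs mode="trilinear" ("bilinear" only accepts 4D inputs), and in that case neighboring channels get blended together.

```python
import torch
import torch.nn.functional as F


def upsample_with_channels(x, scale_factor=2, mode="nearest"):
    """Hypothetical helper: upsample C, H and W of a [N, C, H, W] tensor."""
    x5 = x.unsqueeze(1)  # [N, 1, C, H, W]: C is now treated as a spatial dim
    out = F.interpolate(x5, scale_factor=scale_factor, mode=mode)
    return out.squeeze(1)  # back to 4D: [N, C*s, H*s, W*s]


x = torch.randn(32, 16, 16, 16)

# Nearest-neighbor: every channel is duplicated scale_factor times
out = upsample_with_channels(x)
print(out.shape)

# Reference: upsample H and W only, then repeat each channel twice
ref = F.interpolate(x, scale_factor=2, mode="nearest").repeat_interleave(2, dim=1)
print(torch.equal(out, ref))

# Linear interpolation across C, H and W (mixes adjacent channels)
out_tri = upsample_with_channels(x, mode="trilinear")
print(out_tri.shape)
```

Whether interpolating across channels is meaningful depends on your use case, since channels usually have no spatial ordering. If you need learned channel expansion, a conv layer after the upsampling is the more common choice.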