F.interpolate: input with spatial dimensions of [] error

I am following a U-Net tutorial and I am currently stuck on an interpolation step. My input `x` has shape `[256, 16384]`: a batch of 256 with 1 channel, where a 128x128 image is flattened to 16384.
Thus, I reshape `x` to `[256, 1, 128, 128]`. After some transformations, my output prior to a reshape is `[256, 1, 36, 36]`. I reshape it to `[256, 1, 36*36]` and finally squeeze it to `[256, 1296]`. My print-outs show that everything up to this point is correct. Finally, I would like to interpolate back to `[256, 16384]`.
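As a sanity check (note that 1296 = 36 * 36, so the decoder output described above should be `[256, 1, 36, 36]`), the shape pipeline can be reproduced with dummy tensors:

```python
import torch

x = torch.randn(256, 16384)               # flattened 128x128 input, batch of 256
x = torch.reshape(x, (256, 1, 128, 128))  # add channel dim, unflatten spatial dims
print(x.shape)  # torch.Size([256, 1, 128, 128])

out = torch.randn(256, 1, 36, 36)         # dummy stand-in for the decoder output
out = torch.reshape(out, (256, 1, 36 * 36))
out = torch.squeeze(out)                  # drops the singleton channel dim
print(out.shape)  # torch.Size([256, 1296])
```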

```python
def forward(self, x):
    print(f'dim of x is {x.shape}')
    src_dims = (x.shape[0], 1, 128, 128)
    z = self.encoder(torch.reshape(x, src_dims))
    out = self.decoder(z[::-1][0], z[::-1][1:])
    out = torch.reshape(out, (out.shape[0], 1, out.shape[2] * out.shape[3]))
    out = torch.squeeze(out)
    print(f'shape of OUT PRIOR interpolate is {out.shape}')
    if self.retain_dim:
        out = F.interpolate(out, (x.shape[0], x.shape[1]))
    print(f'shape of OUT after squeeze is {out.shape}')
    z = z[0]
    z = torch.squeeze(z)
    return out, z
```

My print-outs before interpolate are:

```
dim of x is torch.Size([256, 16384])
shape of OUT PRIOR interpolate is torch.Size([256, 1296])
```

So now I assumed that if I plug it into `F.interpolate`, it would interpolate a `[256, 1296]` tensor to a `[256, 16384]` tensor.
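The reason this assumption fails (explained in the answer below my error message) is that `F.interpolate` treats a 2D input as `(N, C)` with zero spatial dimensions, so passing a 2-tuple as `size` mismatches the (empty) spatial shape. A minimal reproduction of the error:

```python
import torch
import torch.nn.functional as F

out = torch.randn(256, 1296)  # 2D tensor: interpreted as (N, C), no spatial dims

try:
    # size must list only spatial dims, so a 2-tuple cannot match [] spatial dims
    F.interpolate(out, (256, 16384))
except ValueError as e:
    print(e)  # "Input and output must have the same number of spatial dimensions, ..."
```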

Nevertheless, I get this error:

```
ValueError: Input and output must have the same number of spatial dimensions,
but got input with spatial dimensions of [] and output size of (256, 16384).
Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size
in (o1, o2, ...,oK) format.
```

`F.interpolate` expects sizes for all spatial dimensions in its `size` argument, while it seems you are passing the batch size as well.
If your input has the shape `[256, 16384]`, I assume `256` corresponds to the batch size and `16384` to the spatial/temporal dimension. If so, you would need to `unsqueeze` a channel dimension and provide a single value to `size`:

```python
import torch
import torch.nn.functional as F

out = torch.randn(256, 16384)
out = out.unsqueeze(1)
print(out.shape)
# torch.Size([256, 1, 16384])

out = F.interpolate(out, size=1296)
print(out.shape)
# torch.Size([256, 1, 1296])
```
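One detail worth noting (not stated above, but documented PyTorch behavior): `F.interpolate` defaults to `mode='nearest'`, which simply repeats values. For a smoother result on a 1D `(N, C, L)` signal, you can pass `mode='linear'`, for example when going back up to 16384:

```python
import torch
import torch.nn.functional as F

out = torch.randn(256, 1, 1296)

# default mode is 'nearest'; 'linear' interpolates smoothly over the last dim
up = F.interpolate(out, size=16384, mode='linear', align_corners=False)
print(up.shape)  # torch.Size([256, 1, 16384])
```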

I have solved it this way, but now I am not sure whether I am doing the right thing, i.e. whether I have used everything as intended.

```python
def forward(self, x):
    # x is of shape (256, 16384), where 256 is the batch size and 16384 == 128**2
    src_dims = (x.shape[0], 1, 128, 128)
    x = torch.reshape(x, src_dims)
    # now x is of shape (256, 1, 128, 128), where 256 is batch size and 1 is channel
    z = self.encoder(x)
    out = self.decoder(z[::-1][0], z[::-1][1:])
    # here out is of shape (256, 1, 36, 36)
    out = torch.reshape(out, (out.shape[0], 1, out.shape[2] * out.shape[3]))
    # out is at this point of shape (256, 1, 1296)
    if self.retain_dim:
        out = F.interpolate(out, size=128 * 128)
        # after interpolate out is of shape (256, 1, 16384)
    out = torch.squeeze(out)
    # and finally out is of shape (256, 16384)
    z = z[0]
    z = torch.squeeze(z)
    return out, z
```