F.interpolate: Input and output must have the same number of spatial dimensions

I am following a U-Net tutorial and I am currently stuck on some interpolation. My x comes in as [256, 16384], which is a batch of 256 with 1 channel, where the 128x128 image is flattened to 16384.
Thus, I reshape x to [256, 1, 128, 128]. Then I apply some transformations, and my out prior to the reshape is [256, 1, 36, 36]. I reshape it to [256, 1, 36*36] and finally squeeze it to [256, 1296]. My print-outs show me that everything up to that point is correct. Finally, I would like to interpolate back to [256, 16384].
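
To make the shape bookkeeping explicit, here is a rough sketch of what I expect to happen (the random tensor just stands in for the encoder/decoder/head, and the 36x36 is simply what my head happens to produce):

    import torch

    x = torch.randn(256, 16384)                      # flattened 128x128 input
    x = torch.reshape(x, (x.shape[0], 1, 128, 128))  # [256, 1, 128, 128]
    out = torch.randn(x.shape[0], 1, 36, 36)         # stand-in for the network output
    out = torch.reshape(out, (out.shape[0], 1, -1))  # [256, 1, 1296]
    out = torch.squeeze(out)                         # [256, 1296]
    # goal: interpolate this back to [256, 16384]

My actual forward looks like this: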

    def forward(self, x):
        print(f'dim of x is {x.shape}')
        src_dims = (x.shape[0], 1, 128, 128)
        z = self.encoder(torch.reshape(x, src_dims))
        out = self.decoder(z[::-1][0], z[::-1][1:])
        out = self.head(out)
        out = torch.reshape(out, (out.shape[0], 1, out.shape[2] * out.shape[3]))
        out = torch.squeeze(out)
        print(f'shape of OUT PRIOR interpolate is {out.shape}')
        if self.retain_dim:
            out = F.interpolate(out, (x.shape[0], x.shape[1]))
        print(f'shape of OUT after squeeze is {out.shape}')
        z = z[0]
        z = self.head(z)
        z = torch.squeeze(z)
        return out, z

My print-outs before interpolate are:

    dim of x is torch.Size([256, 16384])
    shape of OUT PRIOR interpolate is torch.Size([256, 1296])

So now I assume that if I plug it into F.interpolate, it should interpolate a [256, 1296] tensor to a [256, 16384] tensor.

Nevertheless, this is the error I get:

    ValueError: Input and output must have the same number of spatial dimensions,
    but got input with spatial dimensions of [] and output size of (256, 16384).
    Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size
    in (o1, o2, ...,oK) format.

F.interpolate expects sizes for all spatial dimensions in its size argument, while it seems you are passing the batch size as well.
If your input has the shape [256, 16384], I assume 256 corresponds to the batch size and 16384 to the spatial/temporal dimension. If so, you would need to unsqueeze the channel dimension and provide a single value to size:

    out = torch.randn(256, 16384)
    out = out.unsqueeze(1)
    print(out.shape)
    # torch.Size([256, 1, 16384])

    out = F.interpolate(out, (1296))
    print(out.shape)
    # torch.Size([256, 1, 1296])
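
Going the other direction (which I assume is what you are after here, i.e. upsampling [256, 1296] back to [256, 16384]) works the same way; a minimal sketch:

    import torch
    import torch.nn.functional as F

    out = torch.randn(256, 1296)
    out = out.unsqueeze(1)             # [256, 1, 1296]
    out = F.interpolate(out, 128*128)  # [256, 1, 16384]
    out = out.squeeze(1)               # [256, 16384]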

I have solved it this way, but now I am not sure whether I am doing the right thing, i.e. whether I have used everything as intended.


    def forward(self, x):
        # x is of shape (256, 16384) where 256 is the batch size and 16384 == 128**2
        src_dims = (x.shape[0], 1, 128, 128)
        x = torch.reshape(x, src_dims)
        # now x is of shape (256, 1, 128, 128) where 256 is the batch size and 1 is the channel dimension
        z = self.encoder(x)
        out = self.decoder(z[::-1][0], z[::-1][1:])
        out = self.head(out)
        # here out is of shape (256, 1, 36, 36)
        out = torch.reshape(out, (out.shape[0], 1, out.shape[2] * out.shape[3]))
        # out is at this point of shape (256, 1, 1296)
        if self.retain_dim:
            out = F.interpolate(out, 128*128)
        # after interpolate out is of shape (256, 1, 16384)
        out = torch.squeeze(out)
        # and finally out is of shape (256, 16384)
        z = z[0]
        z = self.head(z)
        z = torch.squeeze(z)
        return out, z

Your code looks alright, as apparently no operation fails, and you are also making sure to keep the batch size equal.
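
One thing you might want to double check (this is just an assumption about what you want): F.interpolate uses mode='nearest' by default, so the upsampled values are simply repeated. For a 3D (N, C, L) input you could pass mode='linear' if you prefer a smoother interpolation:

    import torch
    import torch.nn.functional as F

    out = torch.randn(256, 1, 1296)
    out = F.interpolate(out, size=128*128, mode='linear', align_corners=False)
    print(out.shape)
    # torch.Size([256, 1, 16384])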