F.interpolate: input with spatial dimensions of [] error

I am following a U-Net tutorial and I am currently stuck on an interpolation step. My input `x` has shape `[256, 16384]`: a batch of 256 with 1 channel, where a 128x128 image is flattened to 16384.
Thus, I reshape `x` to `[256, 1, 128, 128]`. After some transformations, my output prior to a reshape is `[256, 1, 36, 36]`. I reshape it to `[256, 1, 36*36]` and finally squeeze it to `[256, 1296]`. My print-outs show that everything up to this point is correct. Finally, I would like to interpolate back to `[256, 16384]`.
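As a sanity check (note that 1296 = 36 * 36, so the decoder output described above should be `[256, 1, 36, 36]`), the shape pipeline can be reproduced with dummy tensors:

```python
import torch

x = torch.randn(256, 16384)               # flattened 128x128 input, batch of 256
x = torch.reshape(x, (256, 1, 128, 128))  # add channel dim, unflatten spatial dims
print(x.shape)  # torch.Size([256, 1, 128, 128])

out = torch.randn(256, 1, 36, 36)         # dummy stand-in for the decoder output
out = torch.reshape(out, (256, 1, 36 * 36))
out = torch.squeeze(out)                  # drops the singleton channel dim
print(out.shape)  # torch.Size([256, 1296])
```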

```python
def forward(self, x):
    print(f'dim of x is {x.shape}')
    src_dims = (x.shape[0], 1, 128, 128)
    z = self.encoder(torch.reshape(x, src_dims))
    out = self.decoder(z[::-1][0], z[::-1][1:])
    out = torch.reshape(out, (out.shape[0], 1, out.shape[2] * out.shape[3]))
    out = torch.squeeze(out)
    print(f'shape of OUT PRIOR interpolate is {out.shape}')
    if self.retain_dim:
        out = F.interpolate(out, (x.shape[0], x.shape[1]))
    print(f'shape of OUT after squeeze is {out.shape}')
    z = z[0]
    z = torch.squeeze(z)
    return out, z
```

My print-outs before interpolate are:

```
dim of x is torch.Size([256, 16384])
shape of OUT PRIOR interpolate is torch.Size([256, 1296])
```

So now I assumed that if I plug it into `F.interpolate`, it would interpolate a `[256, 1296]` tensor to a `[256, 16384]` tensor.
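The reason this assumption fails (explained in the answer below my error message) is that `F.interpolate` treats a 2D input as `(N, C)` with zero spatial dimensions, so passing a 2-tuple as `size` mismatches the (empty) spatial shape. A minimal reproduction of the error:

```python
import torch
import torch.nn.functional as F

out = torch.randn(256, 1296)  # 2D tensor: interpreted as (N, C), no spatial dims

try:
    # size must list only spatial dims, so a 2-tuple cannot match [] spatial dims
    F.interpolate(out, (256, 16384))
except ValueError as e:
    print(e)  # "Input and output must have the same number of spatial dimensions, ..."
```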

Nevertheless, I get this error:

```
ValueError: Input and output must have the same number of spatial dimensions,
but got input with spatial dimensions of [] and output size of (256, 16384).
Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size
in (o1, o2, ...,oK) format.
```

`F.interpolate` expects sizes for all spatial dimensions in its `size` argument, while it seems you are passing the batch size as well.
If your input has the shape `[256, 16384]`, I assume `256` corresponds to the batch size and `16384` to the spatial/temporal dimension. If so, you would need to `unsqueeze` a channel dimension and provide a single value to `size`:

```python
import torch
import torch.nn.functional as F

out = torch.randn(256, 16384)
out = out.unsqueeze(1)
print(out.shape)
# torch.Size([256, 1, 16384])

out = F.interpolate(out, size=1296)
print(out.shape)
# torch.Size([256, 1, 1296])
```
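One detail worth noting (not stated above, but documented PyTorch behavior): `F.interpolate` defaults to `mode='nearest'`, which simply repeats values. For a smoother result on a 1D `(N, C, L)` signal, you can pass `mode='linear'`, for example when going back up to 16384:

```python
import torch
import torch.nn.functional as F

out = torch.randn(256, 1, 1296)

# default mode is 'nearest'; 'linear' interpolates smoothly over the last dim
up = F.interpolate(out, size=16384, mode='linear', align_corners=False)
print(up.shape)  # torch.Size([256, 1, 16384])
```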

I have solved it this way, but now I am not sure whether I am doing the right thing, i.e. whether I have used everything as intended.

```python
def forward(self, x):
    # x is of shape (256, 16384), where 256 is the batch size and 16384 == 128**2
    src_dims = (x.shape[0], 1, 128, 128)
    x = torch.reshape(x, src_dims)
    # now x is of shape (256, 1, 128, 128), where 256 is batch size and 1 is channel
    z = self.encoder(x)
    out = self.decoder(z[::-1][0], z[::-1][1:])
    # here out is of shape (256, 1, 36, 36)
    out = torch.reshape(out, (out.shape[0], 1, out.shape[2] * out.shape[3]))
    # out is at this point of shape (256, 1, 1296)
    if self.retain_dim:
        out = F.interpolate(out, size=128 * 128)
        # after interpolate out is of shape (256, 1, 16384)
    out = torch.squeeze(out)
    # and finally out is of shape (256, 16384)
    z = z[0]
    z = torch.squeeze(z)
    return out, z
```