How do I properly use the .interpolate function?

I have a tensor, pred, which has a .size of torch.Size([8, 28, 161]). I want it to match the shape of outputs, which has a .size of torch.Size([8, 27, 161]), so I'm doing:

pred = torch.nn.functional.interpolate(outputs, size=outputs.size())

But this gives me an error:

  File "train_reconstruction.py", line 204, in main
    pred = torch.nn.functional.interpolate(outputs, size=outputs.size())
  File "/home/shamoon/.local/share/virtualenvs/pytorch-lstm-audio-o2htQyti/lib/python3.6/site-packages/torch/nn/functional.py", line 2510, in interpolate
    return torch._C._nn.upsample_nearest1d(input, _output_size(1))
RuntimeError: It is expected output_size equals to 1, but got size 3

What am I doing incorrectly?

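F.interpolate resizes only the temporal/spatial/volumetric dimensions, i.e. the dimensions after the batch and channel dimensions ([N, C, L], [N, C, H, W] or [N, C, D, H, W]).
For your 3D input only the last dimension is resized, so size expects a single value rather than the full shape: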

import torch
import torch.nn.functional as F

x = torch.randn(8, 28, 161)
out = F.interpolate(x, size=140)
print(out.shape)
> torch.Size([8, 28, 140])
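A minimal sketch of that rule, assuming a reasonably recent PyTorch: the number of values in size has to equal the number of dimensions after [N, C], i.e. one for a 3D input, two for 4D and three for 5D. Passing the full x.size() for a 3D input is what triggers the error in the question. Depending on the PyTorch version this surfaces as a RuntimeError (as in the traceback above) or a ValueError, so the sketch catches both.

```python
import torch
import torch.nn.functional as F

# 3D input [N, C, L]: only L is resized, so size has one entry
x1 = torch.randn(8, 28, 161)
print(F.interpolate(x1, size=140).shape)          # torch.Size([8, 28, 140])

# 4D input [N, C, H, W]: H and W are resized, so size has two entries
x2 = torch.randn(8, 3, 32, 32)
print(F.interpolate(x2, size=(64, 48)).shape)     # torch.Size([8, 3, 64, 48])

# 5D input [N, C, D, H, W]: D, H and W are resized, so size has three entries
x3 = torch.randn(2, 1, 4, 8, 8)
print(F.interpolate(x3, size=(8, 16, 16)).shape)  # torch.Size([2, 1, 8, 16, 16])

# Passing the full shape (3 values) for a 3D input fails
raised = False
try:
    F.interpolate(x1, size=x1.size())
except (RuntimeError, ValueError) as err:
    raised = True
    print("error:", type(err).__name__)
```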

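Since you want to interpolate in the channel dimension (dim 1), you could permute the tensor so that this dimension comes last, apply the interpolation, and permute it back. Note that the tensor you want to resize (pred) should be the input, with the target length taken from outputs.size(1):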

x = torch.randn(8, 28, 161)
x = x.permute(0, 2, 1)
x = F.interpolate(x, size=27)
x = x.permute(0, 2, 1)
print(x.shape)
> torch.Size([8, 27, 161])
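If you'd rather not track the permutation by hand, the same idea can be wrapped in a small helper. resize_dim below is a hypothetical name, not a PyTorch function; it's a minimal sketch assuming PyTorch 1.7+ (for torch.movedim). It moves the target dimension last, folds every other dimension into the batch so F.interpolate sees an [N, 1, L] input, and restores the layout afterwards. mode="linear" is chosen as the default here as an assumption for continuous data: it blends neighbouring rows, whereas the default "nearest" used above drops or repeats whole rows.

```python
import torch
import torch.nn.functional as F

def resize_dim(x: torch.Tensor, dim: int, size: int, mode: str = "linear") -> torch.Tensor:
    """Hypothetical helper: resize tensor x along a single dimension `dim` to `size`."""
    # Move the target dimension to the end: [..., L]
    moved = x.movedim(dim, -1)
    lead_shape = moved.shape[:-1]
    # Fold all other dims into the batch dim so F.interpolate sees [N, 1, L]
    flat = moved.reshape(-1, 1, moved.shape[-1])
    # align_corners may only be set for the linear-type modes
    align = False if mode == "linear" else None
    out = F.interpolate(flat, size=size, mode=mode, align_corners=align)
    # Unfold the batch dim and move the resized dimension back into place
    return out.reshape(*lead_shape, size).movedim(-1, dim)

pred = torch.randn(8, 28, 161)
outputs = torch.randn(8, 27, 161)
pred_resized = resize_dim(pred, dim=1, size=outputs.size(1))
print(pred_resized.shape)  # torch.Size([8, 27, 161])
```

With mode="nearest" the helper gives exactly the same result as the permute version above, since every row is resized independently either way.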

Awesome! Thank you. Works perfectly.

I have also encountered this problem. I want to resize a label tensor of size [2, 3, 3] ([B, H, W]) to [2, 5, 5], but I get the error RuntimeError: It is expected output_size equals to 1, but got size 2. What is the proper way to resize a label tensor? @ptrblck

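Since you would like to interpolate the last two dimensions, you have to treat the input as a spatial input, which is expected to have 4 dimensions ([N, C, H, W]).
You could temporarily unsqueeze a new dimension at index 0 (so that your B dimension is treated as the channel dimension), apply the interpolation, and squeeze it afterwards.
Since you've mentioned "label tensor", I assume you are working with LongTensors and are thus interested in the mode="nearest" interpolation, which copies existing label values instead of blending them. If so, this should work (the tensor is cast to float for the interpolation and back to long afterwards):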

import torch
import torch.nn.functional as F

x = torch.randint(0, 10, (2, 3, 3))
y = F.interpolate(x.unsqueeze(0).float(), size=(5, 5), mode='nearest').squeeze(0).long()
print(y.shape)
> torch.Size([2, 5, 5])
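An equivalent variant, sketched here as an assumption rather than taken from the answer above, adds the extra dimension at index 1 instead, so the batch dimension keeps its meaning ([B, 1, H, W]). resize_labels is a hypothetical helper name. The checks at the end illustrate why nearest is the right mode for labels: every output pixel copies one input pixel, so no new label values appear, and for an exact 2x upscale each label is simply repeated in a 2x2 block.

```python
import torch
import torch.nn.functional as F

def resize_labels(labels: torch.Tensor, size) -> torch.Tensor:
    """Hypothetical helper: resize a [B, H, W] integer label map with nearest interpolation."""
    # Add a channel dim -> [B, 1, H, W] and cast to float for F.interpolate
    # (float32 represents integers exactly up to 2**24, ample for class indices)
    x = labels.unsqueeze(1).float()
    y = F.interpolate(x, size=size, mode="nearest")
    # Drop the channel dim and restore the original integer dtype
    return y.squeeze(1).to(labels.dtype)

labels = torch.randint(0, 10, (2, 3, 3))
resized = resize_labels(labels, (5, 5))
print(resized.shape)  # torch.Size([2, 5, 5])

# Only values that already exist in the input can appear in the output
print(set(resized.unique().tolist()) <= set(labels.unique().tolist()))  # True

# For an exact 2x upscale each label is repeated in a 2x2 block
doubled = resize_labels(labels, (6, 6))
expected = labels.repeat_interleave(2, dim=1).repeat_interleave(2, dim=2)
print(torch.equal(doubled, expected))  # True
```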