# How do I properly use the .interpolate function?

I have a tensor, `pred`, which has a `.size()` of `torch.Size([8, 28, 161])`. I want it to match the shape of `outputs`, which has a `.size()` of `torch.Size([8, 27, 161])`, so I’m doing:

```python
pred = torch.nn.functional.interpolate(outputs, size=outputs.size())
```

But this gives me an error:

```
  File "train_reconstruction.py", line 204, in main
    pred = torch.nn.functional.interpolate(outputs, size=outputs.size())
  File "/home/shamoon/.local/share/virtualenvs/pytorch-lstm-audio-o2htQyti/lib/python3.6/site-packages/torch/nn/functional.py", line 2510, in interpolate
RuntimeError: It is expected output_size equals to 1, but got size 3
```

What am I doing incorrectly?

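`F.interpolate` only interpolates the temporal/spatial/volumetric dimensions, i.e. the dimensions after the batch and channel dimensions.
Your input is 3-dimensional (`[batch, channels, length]`), so `size` has to be a single value for the last dimension: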

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 28, 161)
out = F.interpolate(x, size=140)  # resizes the last dimension only
print(out.shape)
# torch.Size([8, 28, 140])
```
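As a rough illustration of that rule (a sketch with arbitrary example shapes, not taken from your code): the number of values `size` accepts matches the number of dimensions that come after batch and channel.

```python
import torch
import torch.nn.functional as F

# 3-D input [N, C, L]: one interpolated dimension, so `size` is a single value
x3 = torch.randn(8, 28, 161)
out3 = F.interpolate(x3, size=140)

# 4-D input [N, C, H, W]: two interpolated dimensions
x4 = torch.randn(8, 28, 16, 16)
out4 = F.interpolate(x4, size=(32, 24))

# 5-D input [N, C, D, H, W]: three interpolated dimensions
x5 = torch.randn(2, 3, 4, 8, 8)
out5 = F.interpolate(x5, size=(6, 10, 12))

print(out3.shape, out4.shape, out5.shape)
```

Passing all three values of `outputs.size()` for a 3-D input is what triggers the `expected output_size equals to 1, but got size 3` error.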

Since you want to interpolate along dim 1, which `F.interpolate` treats as the channel dimension, you could permute the tensor, apply the interpolation, and permute it back:

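```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 28, 161)
x = x.permute(0, 2, 1)         # [8, 161, 28]: move dim 1 to the end
x = F.interpolate(x, size=27)  # [8, 161, 27]
x = x.permute(0, 2, 1)         # [8, 27, 161]: restore the original layout
print(x.shape)
# torch.Size([8, 27, 161])
```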
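If you need this in more than one place, the permute/interpolate/permute steps could be wrapped in a small helper. This is only a sketch: `resize_dim1` is a name made up for illustration, and the random `pred` and `outputs` tensors stand in for your real ones.

```python
import torch
import torch.nn.functional as F


def resize_dim1(x: torch.Tensor, size: int, mode: str = "nearest") -> torch.Tensor:
    """Resize dim 1 of a 3-D tensor [N, C, L] to `size` (hypothetical helper)."""
    x = x.permute(0, 2, 1)                       # [N, L, C]: dim 1 is now last
    x = F.interpolate(x, size=size, mode=mode)   # [N, L, size]
    return x.permute(0, 2, 1)                    # [N, size, L]


# Stand-ins for the tensors from the question
pred = torch.randn(8, 28, 161)
outputs = torch.randn(8, 27, 161)

pred = resize_dim1(pred, outputs.size(1))
print(pred.shape)
```

Note that the result of `permute` is not contiguous in memory, so you may need `.contiguous()` before calling `.view()` on it.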

Awesome! Thank you. Works perfectly.

I have also run into this problem. I want to resize a label tensor of size `[2, 3, 3]` (`[B, H, W]`) to `[2, 5, 5]`, but this raises `RuntimeError: It is expected output_size equals to 1, but got size 2`. What is the proper way to resize a label tensor? @ptrblck

Since you would like to interpolate the last two dimensions, you have to treat the input as a spatial input, which is expected to have 4 dimensions (`[N, C, H, W]`).
You could temporarily add a leading dimension with `unsqueeze` (so that your batch dimension is treated as the channel dimension), apply the interpolation, and `squeeze` it afterwards.
Since you’ve mentioned “label tensor”, I assume you are working with `LongTensor`s and are thus interested in `mode="nearest"` interpolation. If so, this should work:

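```python
import torch
import torch.nn.functional as F

x = torch.randint(0, 10, (2, 3, 3))
# [2, 3, 3] -> [1, 2, 3, 3] -> interpolate H and W -> [2, 5, 5]
y = F.interpolate(x.unsqueeze(0).float(), size=(5, 5), mode='nearest').squeeze(0).long()
print(y.shape)
# torch.Size([2, 5, 5])
```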
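To illustrate why `nearest` is the mode to use for labels, here is a small sketch with a made-up 2x2 label map. The values `0` and `9` are arbitrary class IDs chosen for the example.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([[[0, 9],
                        [9, 0]]])       # [B, H, W] = [1, 2, 2], dtype int64
inp = labels.unsqueeze(0).float()       # [1, B, H, W], float for interpolation

nearest = F.interpolate(inp, size=(4, 4), mode="nearest").squeeze(0).long()
bilinear = F.interpolate(
    inp, size=(4, 4), mode="bilinear", align_corners=False
).squeeze(0)

# nearest only copies existing values, so no new class IDs appear
print(set(nearest.unique().tolist()) <= set(labels.unique().tolist()))
# bilinear blends neighbouring values, which yields fractional "labels"
print(bool((bilinear != bilinear.round()).any()))
```

With `nearest`, every output value is copied from the input, so the result can safely be cast back to `long`. A blending mode such as `bilinear` would produce values that do not correspond to any class.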

Awesome! Thank you. Works perfectly.