Can you do backward propagation with nearest neighbor interpolation?

Hi,

I saw that some PyTorch models use nearest-neighbor interpolation in their upsampling layers (in convolutional decoders, for instance). But to my knowledge, nearest-neighbor interpolation is not differentiable. Shouldn't the model be unable to learn anything because of that?

Or does the backward propagation work similarly to a pooling layer in that case?

Hi Lelouch!

Nearest-neighbor interpolation is differentiable and you can backpropagate
through it as shown in this example:

>>> import torch
>>> torch.__version__
'2.1.1'
>>> s = torch.arange (4.0).reshape (1, 1, 4).requires_grad_()
>>> s
tensor([[[0., 1., 2., 3.]]], requires_grad=True)
>>> t = torch.nn.functional.interpolate (s, size = 8, mode = 'nearest')
>>> t
tensor([[[0., 0., 1., 1., 2., 2., 3., 3.]]],
       grad_fn=<UpsampleNearest1DBackward0>)
>>> t.sum().backward()
>>> s.grad
tensor([[[2., 2., 2., 2.]]])

As for the pooling-layer comparison: yes, loosely speaking, the backward pass does
work like that. Each output element was copied from exactly one input element, so
the gradient of each output element is routed back to that input element, and an
input element that was copied to several output positions accumulates the sum of
their gradients. That is why s.grad is 2. everywhere in the example above: each
input element was copied to two output elements, each of which received a gradient
of 1. from the .sum().
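
To make that routing visible, here is a minimal follow-up sketch (not part of the
example above) that backpropagates a non-uniform gradient through the same
upsampling; each input element receives the sum of the gradients of the two output
elements that were copied from it:

>>> import torch
>>> s = torch.arange (4.0).reshape (1, 1, 4).requires_grad_()
>>> t = torch.nn.functional.interpolate (s, size = 8, mode = 'nearest')
>>> upstream = torch.arange (1.0, 9.0).reshape (1, 1, 8)
>>> t.backward (upstream)
>>> s.grad
tensor([[[ 3.,  7., 11., 15.]]])

The same thing happens when the upsampling sits inside a decoder (for example,
torch.nn.Upsample (mode = 'nearest')), so the gradient flows through it to the
convolution weights below, and the model can learn.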

Best.

K. Frank