The nn.MaxUnpool2d layer routes the gradients straight back to the input
positions selected by the pooling indices, as seen here:
# Demo: nn.MaxUnpool2d routes gradients back to exactly the positions the
# preceding nn.MaxPool2d selected (via the saved indices); every other input
# position receives zero gradient.
#
# NOTE: x is random, so the printed tensor values below are one example run;
# the *structure* of the gradients (1/16 at each pooled position, zeros
# elsewhere) is the same on every run.
import torch
import torch.nn as nn

# return_indices=True makes the pool also return the flat index of each max,
# which MaxUnpool2d needs to place values back.
pool = nn.MaxPool2d(kernel_size=2, return_indices=True)
x = torch.randn(1, 1, 4, 4, requires_grad=True)
print(x)
# e.g. tensor([[[[-0.0680, -0.7748, -0.1858, 0.4980],
#               [-1.9051, -1.8637, -0.3653, 1.4009],
#               [-0.5979, 0.5564, 1.4532, 0.9443],
#               [ 0.2864, -0.3723, -0.9621, 1.0871]]]], requires_grad=True)

act, idx = pool(x)
print(act)
# e.g. tensor([[[[-0.0680, 1.4009],
#               [ 0.5564, 1.4532]]]], grad_fn=<MaxPool2DWithIndicesBackward0>)
print(idx)
# Flat indices (into the 4x4 map) of each window's max, e.g.:
# tensor([[[[ 0, 7],
#           [ 9, 10]]]])

# act is a non-leaf tensor; retain_grad() keeps its .grad so we can inspect it.
act.retain_grad()
unpool = nn.MaxUnpool2d(2)
out = unpool(act, idx)
print(out)
# The pooled values are scattered back to their original positions,
# zeros everywhere else, e.g.:
# tensor([[[[-0.0680, 0.0000, 0.0000, 0.0000],
#           [ 0.0000, 0.0000, 0.0000, 1.4009],
#           [ 0.0000, 0.5564, 1.4532, 0.0000],
#           [ 0.0000, 0.0000, 0.0000, 0.0000]]]],
#        grad_fn=<MaxUnpool2DBackward0>)

# mean() over 16 elements gives each out element a gradient of 1/16 = 0.0625.
out.mean().backward()
print(act.grad)
# Every pooled activation receives the full 0.0625 — unpool passes the
# gradient through unchanged to its input:
# tensor([[[[0.0625, 0.0625],
#           [0.0625, 0.0625]]]])
print(x.grad)
# MaxPool2d then routes each gradient to the argmax position only:
# tensor([[[[0.0625, 0.0000, 0.0000, 0.0000],
#           [0.0000, 0.0000, 0.0000, 0.0625],
#           [0.0000, 0.0625, 0.0625, 0.0000],
#           [0.0000, 0.0000, 0.0000, 0.0000]]]])