Implementing custom backward function for grid_sample

TL;DR: Is it possible to override the gradient w.r.t. one of the inputs to a torch function?

torch.nn.functional.grid_sample with mode="nearest" returns a gradient for grid that is zero everywhere (as expected).

I want to override this behavior and return a custom value for the gradient of grid (e.g., one everywhere, or None), while still returning the default gradient for input. I am trying to write a custom torch.autograd.Function to do this, but I am very confused by the syntax.
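For reference, here is a minimal sketch of the torch.autograd.Function syntax on a simpler op. forward receives ctx plus the inputs; backward receives one incoming gradient per output and must return one value per forward input (None for non-tensor arguments). The class name StraightThroughClamp is hypothetical: it clamps in forward but passes the gradient through unchanged.

```python
import torch


class StraightThroughClamp(torch.autograd.Function):
    """Hypothetical example: clamp in forward, gradient of one in backward."""

    @staticmethod
    def forward(ctx, x, lo, hi):
        # ctx.save_for_backward(...) could stash tensors; none are needed here.
        return x.clamp(lo, hi)

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input: a gradient for x, None for lo and hi.
        return grad_output, None, None


x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
y = StraightThroughClamp.apply(x, -1.0, 1.0)
y.sum().backward()
print(y.detach(), x.grad)  # clamped values, and a gradient of one for every element
```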

Here’s the default behavior:

import torch

input = torch.randn(1, 1, 100, 101, 102, requires_grad=True)
grid = torch.randn(1, 1, 1, 200000, 3, requires_grad=True)

output = torch.nn.functional.grid_sample(
    input=input, grid=grid, mode="nearest", align_corners=False
)

y = output.sum()
y.backward()

input.grad.unique(), grid.grad.unique()
# prints : (tensor([0., 1., 2., 3., 4.]), tensor([0.]))

My interpretation of this is that the gradient for input is the number of times a voxel in input was sampled by grid; the gradient for grid is zero everywhere.
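A small check of that interpretation, using a volume small enough to reason about by hand. With align_corners=False, a grid coordinate of 0 on an axis of size 3 unnormalizes to ((0 + 1) * 3 - 1) / 2 = 1, the center voxel, so five samples placed at the origin should all read voxel (1, 1, 1). The sizes here are chosen only for illustration.

```python
import torch
import torch.nn.functional as F

# 3x3x3 volume; five grid samples, all at normalized coordinate (0, 0, 0).
inp = torch.randn(1, 1, 3, 3, 3, requires_grad=True)
grid = torch.zeros(1, 1, 1, 5, 3, requires_grad=True)

out = F.grid_sample(inp, grid, mode="nearest", align_corners=False)
out.sum().backward()

print(inp.grad[0, 0, 1, 1, 1])  # 5: the center voxel was sampled five times
print(inp.grad.sum())           # 5: no other voxel receives any gradient
print(grid.grad.abs().sum())    # 0: nearest mode gives grid no gradient
```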

Following @gfox’s adaptation of torch.clamp to have a gradient of one everywhere, I tried to write a custom backward pass for grid_sample as follows:

import torch
from torch.cuda.amp import custom_bwd, custom_fwd


class DifferentiableNearestNeighbors(torch.autograd.Function):
    """
    In the forward pass, behaves like `torch.nn.functional.grid_sample`.
    In the backward pass, returns the default gradient for `input`, but
    also returns a gradient for `grid` that is 1 everywhere.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input, grid, align_corners):
        ctx.save_for_backward(input, grid)
        return torch.nn.functional.grid_sample(
            input, grid, mode="nearest", align_corners=align_corners
        )

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        return grad_output.clone(), torch.ones_like(grid), None


def differentiable_nearest_neighbors(input, grid, align_corners):
    return DifferentiableNearestNeighbors.apply(input, grid, align_corners)

However, grad_output is a tensor with the same shape as output, i.e., (1, 1, 1, 1, 200000), rather than the shape of input, and it is one everywhere. My questions are then:

  • How do you turn grad_output into the original gradient for input that grid_sample returns?
  • Is it possible to simply return the original gradient for input but somehow customize the gradient for grid?
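On the first question, one way to see what autograd computes for input is to rebuild it by hand: each output sample reads exactly one voxel, so the input gradient is grad_output scatter-added onto the voxels that were read (with the all-ones grad_output from output.sum(), this reproduces the sample counts). The helper nearest_input_grad below is hypothetical and is a sketch under stated assumptions: a 5-D input, align_corners=False, padding_mode="zeros" (the default), the last grid dimension ordered (x, y, z), i.e. (W, H, D), and nearest mode rounding half to even like torch.round. It illustrates the mapping rather than replacing autograd.

```python
import torch
import torch.nn.functional as F


def nearest_input_grad(grad_output, grid, input_shape):
    """Hypothetical helper: scatter-add grad_output onto the voxels that
    grid_sample(mode="nearest") read. Assumes a 5-D input,
    align_corners=False and padding_mode="zeros"."""
    N, C, D, H, W = input_shape

    def to_index(coord, size):
        # Unnormalize [-1, 1] to voxel coordinates (align_corners=False),
        # then round to the nearest voxel (torch.round rounds half to even).
        return torch.round(((coord + 1) * size - 1) / 2).long()

    # The last grid dimension is ordered (x, y, z), i.e. (W, H, D).
    ix = to_index(grid[..., 0], W)
    iy = to_index(grid[..., 1], H)
    iz = to_index(grid[..., 2], D)
    # Out-of-range samples read zero padding and send back no gradient.
    valid = (ix >= 0) & (ix < W) & (iy >= 0) & (iy < H) & (iz >= 0) & (iz < D)
    flat = (iz * H + iy) * W + ix  # linear voxel index, shape (N, D_out, H_out, W_out)

    grad_input = torch.zeros(N, C, D * H * W, dtype=grad_output.dtype,
                             device=grad_output.device)
    for n in range(N):
        idx = flat[n][valid[n]]            # (P,) voxel index of each valid sample
        src = grad_output[n][:, valid[n]]  # (C, P) incoming gradient of each sample
        grad_input[n].index_add_(1, idx, src)
    return grad_input.view(N, C, D, H, W)


# Usage: compare against autograd on a small random problem.
torch.manual_seed(0)
inp = torch.randn(1, 1, 10, 11, 12, requires_grad=True)
grid = torch.randn(1, 1, 1, 500, 3)
out = F.grid_sample(inp, grid, mode="nearest", align_corners=False)
grad_out = torch.randn_like(out)
out.backward(grad_out)
manual = nearest_input_grad(grad_out, grid, inp.shape)
print(torch.allclose(manual, inp.grad))
```

Inside a custom backward, a helper like this would let you return a correctly shaped input gradient instead of grad_output itself, at the cost of duplicating grid_sample's indexing rules.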

Hi Vivek!

You can achieve this by using register_hook() on your grid tensor:

import torch
print (torch.__version__)

def grid_sample_with_hook (input, grid):
    _ = grid.register_hook (lambda grad: torch.ones_like (grad))
    return  torch.nn.functional.grid_sample (input, grid, mode = "nearest", align_corners = False)
    
input = torch.randn(1, 1, 100, 101, 102, requires_grad=True)
grid = torch.randn(1, 1, 1, 200000, 3, requires_grad=True)

print ('use grid_sample():')
output = torch.nn.functional.grid_sample(
    input=input, grid=grid, mode="nearest", align_corners=False
)

y = output.sum()
y.backward()

print ('input.grad.unique():', input.grad.unique())
print ('grid.grad.unique(): ', grid.grad.unique())

input.grad = None
grid.grad = None
print ('use grid_sample_with_hook():')
output = grid_sample_with_hook (input, grid)

y = output.sum()
y.backward()

print ('input.grad.unique():', input.grad.unique())
print ('grid.grad.unique(): ', grid.grad.unique())

yielding the following output:

2.3.1
use grid_sample():
input.grad.unique(): tensor([0., 1., 2., 3.])
grid.grad.unique():  tensor([0.])
use grid_sample_with_hook():
input.grad.unique(): tensor([0., 1., 2., 3.])
grid.grad.unique():  tensor([1.])
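One caveat with hooking the leaf grid directly: as I understand it, a hook on a leaf sees grid's total gradient, so if grid also feeds other ops, their contribution would be overwritten too, and each call of grid_sample_with_hook() adds another hook to the same leaf. A hedged variant is to hook a fresh alias of grid (here grid.view_as(grid)), so only grid_sample's contribution is replaced. The function name grid_sample_with_local_hook is hypothetical.

```python
import torch
import torch.nn.functional as F


def grid_sample_with_local_hook(inp, grid):
    # Hook a per-call alias of grid rather than the leaf itself, so only the
    # gradient flowing back through grid_sample is replaced with ones.
    grid_alias = grid.view_as(grid)
    grid_alias.register_hook(lambda grad: torch.ones_like(grad))
    return F.grid_sample(inp, grid_alias, mode="nearest", align_corners=False)


torch.manual_seed(0)
inp = torch.randn(1, 1, 10, 11, 12, requires_grad=True)
grid = torch.randn(1, 1, 1, 500, 3, requires_grad=True)

# grid also feeds a second term whose gradient (2 everywhere) should survive.
loss = grid_sample_with_local_hook(inp, grid).sum() + (2 * grid).sum()
loss.backward()
print(grid.grad.unique())  # ones from the hook plus twos from the second term
```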

As for writing a custom autograd.Function, I don’t think that is a practical
approach. Once you put grid_sample() inside of an autograd.Function, it, itself,
will no longer be tracked by autograd. I suppose that you could call
grid_sample() twice, once so that autograd computes its gradient for input
and then again inside of a Function so that autograd picks up your
modified grid.grad, but doing so would seem round-about and inefficient.
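For completeness, a rough sketch of what that round-about approach could look like. In backward, it re-runs grid_sample() under torch.enable_grad() on detached copies, asks torch.autograd.grad() for autograd's own input gradient, and returns ones for grid. The class name is hypothetical, the custom_fwd/custom_bwd decorators are left out, and the second grid_sample() call is exactly the extra cost mentioned above.

```python
import torch
import torch.nn.functional as F


class NearestWithCustomGridGrad(torch.autograd.Function):
    """Sketch: forward is grid_sample(mode="nearest"); backward reuses
    autograd's gradient for input but substitutes ones for grid."""

    @staticmethod
    def forward(ctx, input, grid, align_corners):
        ctx.save_for_backward(input, grid)
        ctx.align_corners = align_corners
        return F.grid_sample(input, grid, mode="nearest", align_corners=align_corners)

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        grad_input = None
        if ctx.needs_input_grad[0]:
            # Grad mode is off inside backward, so re-enable it for the second call.
            with torch.enable_grad():
                inp = input.detach().requires_grad_()
                out = F.grid_sample(inp, grid.detach(), mode="nearest",
                                    align_corners=ctx.align_corners)
                (grad_input,) = torch.autograd.grad(out, inp, grad_output)
        grad_grid = torch.ones_like(grid) if ctx.needs_input_grad[1] else None
        return grad_input, grad_grid, None  # None for align_corners


torch.manual_seed(0)
inp = torch.randn(1, 1, 10, 11, 12, requires_grad=True)
grid = torch.randn(1, 1, 1, 500, 3, requires_grad=True)
out = NearestWithCustomGridGrad.apply(inp, grid, False)
out.sum().backward()
print(inp.grad.unique(), grid.grad.unique())
```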

Best.

K. Frank
