TLDR: is it possible to overwrite the gradient w.r.t. one of the inputs to a torch function?
`torch.nn.functional.grid_sample` with `mode="nearest"` returns a gradient for `grid` that is zero everywhere (as expected).
I want to overwrite this behavior and return a custom value for the gradient of `grid` (e.g., one everywhere, or `None`), while still returning the default gradient for `input`. I am trying to write a custom `torch.autograd.Function` to do this, but I am very confused by the syntax.
Here’s the default behavior:
```python
import torch

input = torch.randn(1, 1, 100, 101, 102, requires_grad=True)
grid = torch.randn(1, 1, 1, 200000, 3, requires_grad=True)

output = torch.nn.functional.grid_sample(
    input=input, grid=grid, mode="nearest", align_corners=False
)
y = output.sum()
y.backward()

input.grad.unique(), grid.grad.unique()
# prints: (tensor([0., 1., 2., 3., 4.]), tensor([0.]))
```
My interpretation of this is that the gradient for `input` is the number of times each voxel in `input` was sampled by `grid`; the gradient for `grid` is zero everywhere.
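If that reading is correct, the following sanity check should hold (my own reasoning, assuming the default `padding_mode="zeros"`): sampling an all-ones volume returns 1 for every grid point that lands inside `input` and 0 otherwise, so its sum should equal `input.grad.sum()`.

```python
# Sanity check of the interpretation above: with an all-ones input, every
# in-bounds grid point samples the value 1 (out-of-bounds points return 0
# under the default padding_mode="zeros"), so the two sums should match.
with torch.no_grad():
    hits = torch.nn.functional.grid_sample(
        torch.ones_like(input), grid, mode="nearest", align_corners=False
    ).sum()
print(input.grad.sum().item(), hits.item())  # I expect these to be equal
```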
Following @gfox’s adaptation of `torch.clamp` to have a gradient of one everywhere, I tried to write a custom backward pass for `grid_sample`.
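For context, here is roughly the pattern I am following (my own reconstruction, not @gfox’s exact code): a `torch.autograd.Function` whose forward clamps, but whose backward passes the incoming gradient straight through as if clamp were the identity.

```python
import torch


class ClampWithGradOne(torch.autograd.Function):
    """Behaves like `torch.clamp` in the forward pass; in the backward pass,
    pretends clamp is the identity, so the gradient is 1 everywhere."""

    @staticmethod
    def forward(ctx, x, lo, hi):
        return x.clamp(lo, hi)

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient through unchanged; `lo` and `hi` are
        # non-tensor arguments, so they get None.
        return grad_output.clone(), None, None
```

My attempt to do the same for `grid_sample` looks like this: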
```python
from torch.cuda.amp import custom_bwd, custom_fwd


class DifferentiableNearestNeighbors(torch.autograd.Function):
    """
    In the forward pass, behaves like `torch.nn.functional.grid_sample`.
    In the backward pass, returns the default gradient for `input`, but
    also returns a gradient for `grid` that is 1 everywhere.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input, grid, align_corners):
        ctx.save_for_backward(input, grid)
        return torch.nn.functional.grid_sample(
            input, grid, mode="nearest", align_corners=align_corners
        )

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        # grad_output has the shape of `output`, not `input`, so this first
        # return value is not the default gradient for `input` -- this is
        # where I am stuck.
        return grad_output.clone(), torch.ones_like(grid), None


def differentiable_nearest_neighbors(input, grid, align_corners):
    return DifferentiableNearestNeighbors.apply(input, grid, align_corners)
```
However, `grad_output` is a tensor with the same shape as `output`, i.e., `(1, 1, 1, 1, 200000)`, and it is one everywhere, which is not the gradient for `input` that I want to pass through. My questions are then:
- How do you turn `grad_output` into the original gradient for `input` that `grid_sample` returns?
- Is it possible to simply return the original gradient for `input` but somehow customize the gradient for `grid`?
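One idea I had, but have not managed to verify, is to recompute the default gradient for `input` inside `backward` by re-running `grid_sample` under `torch.enable_grad()` and calling `torch.autograd.grad`, while returning ones for `grid`. A rough, untested sketch (the class name is mine):

```python
class GridSampleCustomGridGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, grid, align_corners):
        ctx.save_for_backward(input, grid)
        ctx.align_corners = align_corners
        return torch.nn.functional.grid_sample(
            input, grid, mode="nearest", align_corners=align_corners
        )

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        # Recompute the default gradient w.r.t. `input` by re-running the op
        # with autograd enabled and back-propagating grad_output through it.
        with torch.enable_grad():
            input_ = input.detach().requires_grad_(True)
            output = torch.nn.functional.grid_sample(
                input_,
                grid.detach(),
                mode="nearest",
                align_corners=ctx.align_corners,
            )
            (grad_input,) = torch.autograd.grad(output, input_, grad_output)
        # Custom gradient for `grid`: one everywhere instead of zero.
        return grad_input, torch.ones_like(grid), None
```

Is something like this the intended approach, or is there a way to reuse the gradient that `grid_sample`'s own backward already computes, without re-running the forward pass?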