TL;DR: Is it possible to overwrite the gradient w.r.t. one of the inputs to a torch function?

`torch.nn.functional.grid_sample` with `mode="nearest"` returns a gradient for `grid` that is zero everywhere (as expected). I want to overwrite this behavior and return a custom value for the gradient of `grid` (e.g., one everywhere, or `None`). I also want to return the default gradient of `input`. I am trying to write a custom `torch.autograd.Function` to do this, but I am very confused by the syntax.
Here’s the default behavior:
```python
import torch

input = torch.randn(1, 1, 100, 101, 102, requires_grad=True)
grid = torch.randn(1, 1, 1, 200000, 3, requires_grad=True)

output = torch.nn.functional.grid_sample(
    input=input, grid=grid, mode="nearest", align_corners=False
)
y = output.sum()
y.backward()

input.grad.unique(), grid.grad.unique()
# prints: (tensor([0., 1., 2., 3., 4.]), tensor([0.]))
```
My interpretation of this is that the gradient for `input` counts how many times each voxel of `input` was sampled by `grid`, while the gradient for `grid` is zero everywhere.
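As a quick sanity check of that reading (my assumption: with the default `padding_mode="zeros"`, every grid point that lands inside the volume adds exactly 1 to the gradient of the voxel it snaps to, and out-of-bounds points contribute nothing):

```python
# Total gradient mass should be about the number of grid points that landed
# inside the volume (at most 200000 here), and the gradient for grid is exactly zero.
print(input.grad.sum())       # roughly the number of in-bounds sample points
print(grid.grad.abs().max())  # tensor(0.)
```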
Following @gfox's adaptation of `torch.clamp` to have a gradient of one everywhere, I tried to write a custom backward pass for `grid_sample`. The clamp trick is, roughly, a straight-through pattern like this (my paraphrase, not the exact snippet from that post):
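```python
# Rough paraphrase of the straight-through clamp idea (names are mine):
# the forward clamps as usual, the backward pretends the derivative is 1
# everywhere by passing the incoming gradient through unchanged.
class ClampWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lo, hi):
        return x.clamp(lo, hi)

    @staticmethod
    def backward(ctx, grad_output):
        # gradient of 1 w.r.t. x; no gradient for the bounds
        return grad_output.clone(), None, None
```

My adaptation for `grid_sample` looks like this: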
```python
from torch.cuda.amp import custom_bwd, custom_fwd


class DifferentiableNearestNeighbors(torch.autograd.Function):
    """
    In the forward pass, behaves like `torch.nn.functional.grid_sample`.
    In the backward pass, returns the default gradient for `input`, but
    also returns a gradient for `grid` that is 1 everywhere.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input, grid, align_corners):
        ctx.save_for_backward(input, grid)
        return torch.nn.functional.grid_sample(
            input, grid, mode="nearest", align_corners=align_corners
        )

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        # custom gradient for grid: one everywhere; no gradient for align_corners
        return grad_output.clone(), torch.ones_like(grid), None


def differentiable_nearest_neighbors(input, grid, align_corners):
    return DifferentiableNearestNeighbors.apply(input, grid, align_corners)
```
However, `grad_output` is a tensor with the same shape as `output`, i.e., `(1, 1, 1, 1, 200000)`, but is one everywhere. My questions are then:
- How do you turn `grad_output` into the original gradient for `input` that `grid_sample` returns?
- Is it possible to simply return the original gradient for `input` but somehow customize the gradient for `grid`?
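For context, the direction I am currently considering is to recompute the forward pass inside `backward` with gradients enabled, and let autograd hand me the default gradient for `input` while I return whatever I want for `grid`. This is an untested sketch (the class name is just mine, and the recomputation doubles the forward cost), so please correct me if there is a cleaner approach:

```python
class DifferentiableNearestNeighborsV2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, grid, align_corners):
        ctx.save_for_backward(input, grid)
        ctx.align_corners = align_corners  # stash the non-tensor argument
        return torch.nn.functional.grid_sample(
            input, grid, mode="nearest", align_corners=align_corners
        )

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        # Recompute the forward with autograd enabled so torch produces the
        # *default* gradient w.r.t. `input` for this particular grad_output.
        with torch.enable_grad():
            input_ = input.detach().requires_grad_(True)
            output = torch.nn.functional.grid_sample(
                input_,
                grid.detach(),
                mode="nearest",
                align_corners=ctx.align_corners,
            )
            (grad_input,) = torch.autograd.grad(output, input_, grad_output)
        # Custom gradient for grid: one everywhere instead of the default zeros.
        return grad_input, torch.ones_like(grid), None
```

Is something like this the right way to do it, or is there a way to reuse `grid_sample`'s own backward for `input` directly?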