RuntimeError: derivative for aten::grid_sampler_3d_backward is not implemented (minimal reproducible example attached)

Hello! I'm trying to render a normal map from a density field. One way to do this is to take the gradient of the density field with respect to the point locations; that gradient then participates in the loss calculation. An example:
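To spell out why this needs a second derivative, here is a sketch of the quantities involved. Write $F$ for the feature grid, $\mathbf{x}$ for a point, and $\sigma$ for the trilinearly interpolated density that `grid_sample` returns on 5-D input. The normal convention shown (negated, normalized gradient, since density increases going into the surface) is an assumption; flip the sign if your setup differs.

```latex
\sigma(\mathbf{x}) = \operatorname{interp}(F, \mathbf{x}), \qquad
\mathbf{g}(\mathbf{x}) = \nabla_{\mathbf{x}}\,\sigma(\mathbf{x}), \qquad
\mathbf{n}(\mathbf{x}) = -\frac{\mathbf{g}(\mathbf{x})}{\lVert \mathbf{g}(\mathbf{x}) \rVert}

\frac{\partial \mathcal{L}}{\partial F}
  = \frac{\partial \mathcal{L}}{\partial \mathbf{g}}\,\frac{\partial \mathbf{g}}{\partial F}
  = \frac{\partial \mathcal{L}}{\partial \mathbf{g}}\,\frac{\partial^2 \sigma}{\partial F\,\partial \mathbf{x}},
\qquad
\frac{\partial \mathcal{L}}{\partial \mathbf{x}}
  = \frac{\partial \mathcal{L}}{\partial \mathbf{g}}\,\frac{\partial^2 \sigma}{\partial \mathbf{x}^2}
```

Computing $\mathbf{g}$ already runs `grid_sample`'s backward, so training on a loss of $\mathbf{g}$ means differentiating that backward op again. Within a single voxel, trilinear interpolation is linear along each axis, so $\partial^2 \sigma / \partial x_i^2 = 0$ there, but the mixed terms and $\partial^2 \sigma / \partial F\,\partial \mathbf{x}$ are generally nonzero.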

import torch
import torch.nn.functional as F

feat = torch.randn(1, 1, 200, 200, 16) # features
inds = torch.rand(1, 1, 1, 420000, 3) # point locations

feat.requires_grad_(True)
inds.requires_grad_(True)

feat_ = F.grid_sample(feat, inds, mode='bilinear', padding_mode='zeros', align_corners=True) # density field

gradients = torch.autograd.grad(feat_, inds, torch.ones_like(feat_), retain_graph=True, create_graph=True)[0] # (1, 1, 1, 420000, 3)

loss = gradients.mean() # using as an example

loss.backward()

This raises RuntimeError: derivative for aten::grid_sampler_3d_backward is not implemented. Essentially I need this gradient inside the loss, and to get it I'm differentiating twice (a double backward through grid_sample), I think. My questions:
1) Is double backward for 3-D grid_sample planned, and if so, when will it be supported?
2) Are there other ways to do what I'm trying to do that would work?
Does anyone have any ideas? Thanks!
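On question 2, one workaround that avoids double backward entirely is to estimate the gradient with central finite differences: sample the density at x ± h along each axis and take the difference. The loss then depends only on ordinary `grid_sample` outputs, so `loss.backward()` needs only the first-order `grid_sampler_3d_backward`. This is a sketch, not an official PyTorch feature; the function name `fd_density_gradient` and the default step `h` are hypothetical choices.

```python
import torch
import torch.nn.functional as F


def fd_density_gradient(feat: torch.Tensor, pts: torch.Tensor, h: float = 1e-3) -> torch.Tensor:
    """Hypothetical helper: central-difference estimate of d(density)/d(pts).

    feat: (N, C, D, H, W); pts: (N, Do, Ho, Wo, 3) with normalized (x, y, z) coords.
    h is a step in normalized coordinates (assumed default; one voxel along an
    axis of size S spans 2 / (S - 1) when align_corners=True).
    Returns: (N, C, Do, Ho, Wo, 3).
    """
    def sample(p: torch.Tensor) -> torch.Tensor:
        return F.grid_sample(feat, p, mode='bilinear', padding_mode='zeros', align_corners=True)

    grads = []
    for i in range(3):
        # Unit step along axis i (x, y, z order, matching the last dim of pts).
        offset = torch.zeros(3, dtype=pts.dtype, device=pts.device)
        offset[i] = h
        # Only first-order grid_sample calls here, so backpropagating through the
        # result needs grid_sampler_3d_backward but not its derivative.
        grads.append((sample(pts + offset) - sample(pts - offset)) / (2 * h))
    return torch.stack(grads, dim=-1)
```

For the single-channel repro, `fd_density_gradient(feat, inds)[:, 0]` has the same (1, 1, 1, 420000, 3) shape as the autograd result. When both x + h and x - h stay inside one voxel, the estimate matches the analytic gradient, because trilinear interpolation is linear along each axis within a cell; a larger h (around a voxel) averages across cells and tends to give smoother normals.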

torch==2.1.2+cu121
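Another possible workaround for question 2, if the exact gradient is needed: re-implement trilinear interpolation from ordinary differentiable ops (floor, gather, arithmetic) so autograd can differentiate it as many times as needed. This is a sketch under assumptions, not a PyTorch API: the name `trilinear_sample` is hypothetical, it only covers the settings used above (5-D input, mode='bilinear', padding_mode='zeros', align_corners=True), and it has not been tuned for speed or memory.

```python
import torch
import torch.nn.functional as F


def trilinear_sample(feat: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Hypothetical pure-PyTorch stand-in for
    F.grid_sample(feat, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
    on 5-D input, written so autograd can differentiate it more than once.

    feat: (N, C, D, H, W); grid: (N, Do, Ho, Wo, 3) with (x, y, z) in [-1, 1].
    Returns: (N, C, Do, Ho, Wo).
    """
    N, C, D, H, W = feat.shape
    out_shape = grid.shape[:-1]  # (N, Do, Ho, Wo)

    # Unnormalize to voxel coordinates (align_corners=True: -1 -> 0, +1 -> size - 1).
    x = (grid[..., 0] + 1) * (W - 1) / 2
    y = (grid[..., 1] + 1) * (H - 1) / 2
    z = (grid[..., 2] + 1) * (D - 1) / 2

    # Lower corner indices; floor has zero gradient, so gradients flow through the fractions.
    x0, y0, z0 = x.floor(), y.floor(), z.floor()
    fx, fy, fz = x - x0, y - y0, z - z0
    x0, y0, z0 = x0.long(), y0.long(), z0.long()

    feat_flat = feat.reshape(N, C, D * H * W)
    out = feat.new_zeros((N, C) + tuple(out_shape[1:]))

    for dz in (0, 1):
        for dy in (0, 1):
            for dx in (0, 1):
                xi, yi, zi = x0 + dx, y0 + dy, z0 + dz
                # Corners outside the volume contribute zero (padding_mode='zeros').
                valid = ((xi >= 0) & (xi < W) & (yi >= 0) & (yi < H)
                         & (zi >= 0) & (zi < D))
                # Clamp only to keep gather in range; invalid corners are masked below.
                idx = (zi.clamp(0, D - 1) * H + yi.clamp(0, H - 1)) * W + xi.clamp(0, W - 1)
                idx = idx.reshape(N, 1, -1).expand(N, C, -1)
                vals = torch.gather(feat_flat, 2, idx).reshape(out.shape)
                # Trilinear weight of this corner, zeroed if the corner is out of bounds.
                w = ((fx if dx else 1 - fx)
                     * (fy if dy else 1 - fy)
                     * (fz if dz else 1 - fz))
                w = w * valid.to(w.dtype)
                out = out + vals * w.unsqueeze(1)
    return out
```

Swapping this in for the `F.grid_sample` call in the repro should let `torch.autograd.grad(..., create_graph=True)` followed by `loss.backward()` run, at the cost of materializing eight gathered corner tensors per call.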