Extend autograd.Function to support vmap and autograd

Hi,

I am trying to extend some pytorch3d classes (subclasses of autograd.Function) to support jacrev. I made the adjustments on the C++ side and have confirmed that they are correct. In the example below, _C.point_face_dist_backward is the backward method from pytorch3d. It can now take grad_dists as a 2D tensor, which matches how vmap/jacrev would call it.

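import numpy as np
import torch
from pytorch3d import _C

# points, tris, p2f_inds, min_triangle_area and device are assumed to be defined already

# simulating torch.autograd.functional.jacobian
# to compute the jacobian row by row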
grad_dists = torch.eye(points.shape[0], dtype=torch.float32, device=device)
grad_point_list = []
for grad_d in grad_dists:
    grad_points, _ = _C.point_face_dist_backward(
        points, tris, p2f_inds, grad_d, min_triangle_area
    )
    grad_point_list.append(grad_points)
grad_points = torch.stack(grad_point_list)

# simulating torch.func.jacrev
# under vmap, it will be called with an identity grad_dists
vmap_grad_dists = torch.eye(points.shape[0], dtype=torch.float32, device=device)
vmap_grad_points, _ = _C.point_face_dist_backward(
    points, tris, p2f_inds, vmap_grad_dists, min_triangle_area
)

# move to CPU first so the comparison also works when device is a GPU
np.testing.assert_equal(grad_points.cpu().numpy(), vmap_grad_points.cpu().numpy())

The problem:

  1. When this backward method is called through torch.autograd.functional.jacobian (which computes the jacobian row by row), it returns grad_points with shape N x 3. This works fine for building the full jacobian.

  2. When it is called through jacrev, it returns grad_points with shape M x N x 3. The final jacobian should be a sparse 2D tensor with shape M x (3N). I found that this line in vmap.py reshapes the gradients (from M x N x 3 to M x (3N)). The shape is correct, but for some reason the values are duplicated across rows, so the resulting jacobian is a dense 2D tensor full of duplicates.
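To make the expected layout concrete, here is a small sketch using plain tensors. It assumes each distance depends only on its own point (so M equals N), and `per_point_grad` is a random stand-in, not the output of the pytorch3d kernel.

```python
import torch

torch.manual_seed(0)
N = 4

# Stand-in for the gradient of distance i with respect to point i, shape (N, 3).
per_point_grad = torch.randn(N, 3)

# Expected batched backward output, shape (M, N, 3) with M == N:
# slice i is nonzero only at point i.
batched = torch.eye(N).unsqueeze(-1) * per_point_grad

# The reshape that turns it into the 2D jacobian, shape (M, 3N).
jac_2d = batched.reshape(N, 3 * N)

# Row i should carry its 3 values in columns 3i..3i+2 and zeros elsewhere.
for i in range(N):
    assert torch.equal(jac_2d[i, 3 * i : 3 * i + 3], per_point_grad[i])

nonzeros_per_row = (jac_2d != 0).sum(dim=1)
```

If the real jacobian has more than 3 nonzero entries per row, comparing the M x N x 3 tensor before the reshape against this layout should show whether the duplication is already present in what the backward returns.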

The question:

  1. Does anyone have experience extending autograd.Function to support vmap? The most useful sentence from the docs is “For example, torch.func.jacrev() performs vmap() over the backward pass. So if you’re only interested in using torch.func.jacrev(), only the backward() staticmethod needs to be vmappable.” However, the docs do not include an example of how to do that when the backward method calls C++/CUDA code.

  2. If there is an example or a pointer, please also let me know!

Thanks!

For future reference, this was cross-posted and answered here: Make `torch.autograd.Function` support `vmap` · Issue #128020 · pytorch/pytorch · GitHub.
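For anyone landing here, below is a minimal sketch of one way to make a backward pass vmappable: route the kernel call through a second autograd.Function that defines a `vmap` staticmethod. This is an illustration under assumptions, not necessarily the approach from the linked issue. `opaque_backward_kernel`, `SqDist`, and `SqDistBackward` are hypothetical names, and the kernel is a pure PyTorch stand-in (squared distance to the origin) for a C++/CUDA call that already accepts a batched grad_dists.

```python
import torch
from torch.func import jacrev


def opaque_backward_kernel(points, grad_dists):
    # Hypothetical stand-in for a C++/CUDA backward kernel.
    # Accepts grad_dists of shape (N,) or (M, N); returns (N, 3) or (M, N, 3).
    return 2.0 * points * grad_dists.unsqueeze(-1)


class SqDistBackward(torch.autograd.Function):
    """Wraps the backward kernel so vmap knows how to batch it."""

    @staticmethod
    def forward(points, grad_dists):
        return opaque_backward_kernel(points, grad_dists)

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing saved: double backward is not supported in this sketch

    @staticmethod
    def backward(ctx, grad_out):
        raise NotImplementedError("double backward is not implemented")

    @staticmethod
    def vmap(info, in_dims, points, grad_dists):
        points_bdim, grad_bdim = in_dims
        # This sketch only handles batching over grad_dists (the jacrev case).
        assert points_bdim is None
        grad_dists = grad_dists.movedim(grad_bdim, 0)  # batch dim first: (M, N)
        out = opaque_backward_kernel(points, grad_dists)  # (M, N, 3)
        return out, 0  # the output's batch dimension is dim 0


class SqDist(torch.autograd.Function):
    @staticmethod
    def forward(points):
        return (points ** 2).sum(-1)  # one distance per point, shape (N,)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (points,) = inputs
        ctx.save_for_backward(points)

    @staticmethod
    def backward(ctx, grad_dists):
        (points,) = ctx.saved_tensors
        # Under jacrev, grad_dists is batched here; SqDistBackward.vmap handles it.
        return SqDistBackward.apply(points, grad_dists)


def sq_dist(points):
    return SqDist.apply(points)


points = torch.randn(4, 3)
jac = jacrev(sq_dist)(points)  # shape (N, N, 3)
```

The point of the second Function is that its `vmap` staticmethod receives the physical tensors together with `in_dims`, so the batched kernel can be called once with the full M x N grad_dists and the position of the output's batch dimension can be reported explicitly.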