Hi,
I am trying to extend some pytorch3d classes (which are autograd.Function subclasses) to support jacrev. I made the adjustment on the C++ side and can confirm it is correct. In the example below, _C.point_face_dist_backward is the backward method from pytorch3d. It can now take grad_dists as a 2D tensor, matching how vmap/jacrev would call it.
# simulating torch.autograd.functional.jacobian
# to compute the jacobian row by row
grad_dists = torch.eye(points.shape[0], dtype=torch.float32, device=device)
grad_point_list = []
for grad_d in grad_dists:
    grad_points, _ = _C.point_face_dist_backward(
        points, tris, p2f_inds, grad_d, min_triangle_area
    )
    grad_point_list.append(grad_points)
grad_points = torch.stack(grad_point_list)

# simulating torch.func.jacrev
# under vmap, it will be called with an identity grad_dists
vmap_grad_dists = torch.eye(points.shape[0], dtype=torch.float32, device=device)
vmap_grad_points, _ = _C.point_face_dist_backward(
    points, tris, p2f_inds, vmap_grad_dists, min_triangle_area
)
np.testing.assert_equal(grad_points.cpu().numpy(), vmap_grad_points.cpu().numpy())
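For anyone who wants to reproduce the comparison without the pytorch3d C++ extension, here is a self-contained analogue using a pure-PyTorch toy function (toy_dist is made up for illustration; it just stands in for the point-face distance):

```python
import torch

# Toy analogue of the point-face distance: maps (N, 3) points to N scalar
# "distances". Pure PyTorch, so both jacobian APIs work on it directly.
def toy_dist(points):
    return (points ** 2).sum(dim=1)

points = torch.randn(5, 3)

# Row-by-row jacobian, as torch.autograd.functional.jacobian computes it.
jac_rows = torch.autograd.functional.jacobian(toy_dist, points)  # (5, 5, 3)

# jacrev vmaps over the backward pass with an identity grad output.
jac_vmap = torch.func.jacrev(toy_dist)(points)  # (5, 5, 3)

assert torch.allclose(jac_rows, jac_vmap)
```

For built-in ops the two paths agree; the question is how to get the same agreement out of a custom C++ backward.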
The problem:

- When this backward method is called with torch.autograd.functional.jacobian (computing the jacobian row by row), it returns grad_points with shape N x 3. This works fine for building the final full jacobian.
- When it is called with jacrev, it returns grad_points with shape M x N x 3. The final jacobian should be a sparse 2D tensor of shape M x (3N). I found that this line in vmap.py reshapes the gradients from M x N x 3 to M x (3N). While the shape is correct, it for some reason duplicates all the numbers row by row, so the resulting jacobian is a dense 2D tensor with duplications.
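For reference, this is what I would expect that reshape to produce when the batched backward is correct. A toy sketch (shapes and the one-point-per-distance dependency are hypothetical, just to show the sparsity pattern):

```python
import torch

# Toy shapes: M outputs, N points. Assume output i depends only on point i,
# as with the identity grad_dists in the real case.
M = N = 4
per_point_grad = torch.arange(1.0, N * 3 + 1).reshape(N, 3)

# A correct batched backward would return grad_points of shape (M, N, 3)
# where row i is nonzero only at point i.
grad_points = torch.zeros(M, N, 3)
grad_points[torch.arange(M), torch.arange(N)] = per_point_grad

# The reshape vmap.py performs: (M, N, 3) -> (M, 3N).
jacobian = grad_points.reshape(M, N * 3)

# The result should stay sparse: exactly 3 nonzeros per row, no duplication.
assert (jacobian != 0).sum(dim=1).eq(3).all()
```

What I actually get instead is every row densely filled with repeated values.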
The question:

- Does anyone have experience extending autograd.Function to support vmap? The most useful sentence from the docs is "For example, torch.func.jacrev() performs vmap() over the backward pass. So if you're only interested in using torch.func.jacrev(), only the backward() staticmethod needs to be vmappable." But there is no example of how to do that when the backward method calls C++/CUDA code.
- If there is an example or a pointer, please let me know!
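To frame what I have tried so far: here is a minimal sketch of the setup_context / vmap pattern from the "Extending torch.func with autograd.Function" docs, with the C++ kernel replaced by hypothetical pure-PyTorch stand-ins (cpp_like_forward / cpp_like_backward and the Square class are made up for illustration, not the real pytorch3d code):

```python
import torch
from torch.autograd import Function

# Hypothetical stand-ins for a C++ kernel pair; the backward broadcasts
# over a leading batch dim of grad_out, like my adjusted kernel does.
def cpp_like_forward(x):
    return x ** 2

def cpp_like_backward(x, grad_out):
    return 2 * x * grad_out

class Square(Function):
    # New-style forward (no ctx) so torch.func transforms can trace it.
    @staticmethod
    def forward(x):
        return cpp_like_forward(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return cpp_like_backward(x, grad_out)

    @staticmethod
    def vmap(info, in_dims, x):
        # Batching rule for the forward: move the batched dim to the
        # front and call the kernel once on the whole batch.
        x_dim = in_dims[0]
        if x_dim is not None:
            x = x.movedim(x_dim, 0)
        return cpp_like_forward(x), 0 if x_dim is not None else None
```

With this, torch.func.jacrev(Square.apply) works because the backward only uses vmappable tensor ops; what is unclear to me is the equivalent recipe when backward dispatches straight into C++/CUDA.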
Thanks!