Hi,
I am trying to extend some PyTorch3D classes (subclasses of `torch.autograd.Function`) to support `jacrev`. I made the adjustment on the C++ side and can confirm it is correct. In the example below, `_C.point_face_dist_backward` is the backward method from PyTorch3D. It can now take `grad_dists` as a 2D tensor, which is how vmap/jacrev would call it.
```python
import numpy as np
import torch
from pytorch3d import _C

# points, tris, p2f_inds, min_triangle_area and device come from the
# forward pass of the point-face distance (not shown here)

# simulating torch.autograd.functional.jacobian:
# compute the jacobian row by row, one 1D grad_dists at a time
grad_dists = torch.eye(points.shape[0], dtype=torch.float32, device=device)
grad_point_list = []
for grad_d in grad_dists:
    grad_points, _ = _C.point_face_dist_backward(
        points, tris, p2f_inds, grad_d, min_triangle_area
    )
    grad_point_list.append(grad_points)
grad_points = torch.stack(grad_point_list)

# simulating torch.func.jacrev:
# under vmap, the kernel is called once with an identity (2D) grad_dists
vmap_grad_dists = torch.eye(points.shape[0], dtype=torch.float32, device=device)
vmap_grad_points, _ = _C.point_face_dist_backward(
    points, tris, p2f_inds, vmap_grad_dists, min_triangle_area
)
# .cpu() so the comparison also works when device is CUDA
np.testing.assert_equal(grad_points.cpu().numpy(), vmap_grad_points.cpu().numpy())
```
The problem:
- When this `backward` method is called with `torch.autograd.functional.jacobian` (computing the jacobian row by row), it returns `grad_points` with a shape of N x 3. This works fine to build the final full jacobian.
- When it is called with `jacrev`, it returns `grad_points` with a shape of M x N x 3. The final jacobian should be a sparse 2D matrix with a shape of M x (3N). I found a line in `vmap.py` that reshapes the gradients from M x N x 3 to M x (3N). While the shape is correct, for some reason it duplicates all the numbers row by row, so the resulting jacobian is a dense 2D matrix full of duplicates.
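A small sanity check with a pure-PyTorch toy function may help narrow down where the duplication comes from. `dists_fn` below is a hypothetical stand-in for the point-face distance (each output depends only on its own point), not PyTorch3D code. For an output of shape (M,) and an input of shape (N, 3), both `jacrev` and `torch.autograd.functional.jacobian` should return M x N x 3, and flattening to M x (3N) only reinterprets the same values in the same order, so it cannot create new ones. If this check passes but the extension-backed version does not, that suggests the duplicated values are already present in the batched backward output before the reshape.

```python
import torch
from torch.func import jacrev


def dists_fn(points):
    # Toy stand-in for the point-to-face distance: output i depends
    # only on point i, so the true jacobian is block-diagonal.
    return (points ** 2).sum(-1)  # shape (M,), with M == N here


points = torch.randn(5, 3)

jac_rev = jacrev(dists_fn)(points)  # (M, N, 3)
jac_ref = torch.autograd.functional.jacobian(dists_fn, points)  # (M, N, 3)
torch.testing.assert_close(jac_rev, jac_ref)

# Flattening to (M, 3N) is a pure reshape of the same values.
jac_flat = jac_rev.reshape(points.shape[0], -1)  # (M, 3N)
nnz_per_row = (jac_flat != 0).sum(dim=1)
print(nnz_per_row)  # expected: at most 3 nonzeros in each row
```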
The question:
- Does anyone have experience extending `autograd.Function` to support vmap? The most useful sentence from the docs is: "For example, `torch.func.jacrev()` performs `vmap()` over the backward pass. So if you're only interested in using `torch.func.jacrev()`, only the `backward()` staticmethod needs to be vmappable." But the docs do not have an example of exactly how to do that when the `backward` method calls C++/CUDA code.
- If there is an example or a pointer, please also let me know!
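For reference, a minimal sketch of the pattern the docs seem to describe, assuming PyTorch 2.0 or later, where `autograd.Function` can define a separate `setup_context` (required for `torch.func` transforms, so `forward` does not take `ctx`) and a `vmap` staticmethod. The idea is to wrap the opaque backward call in its own `autograd.Function` (`KernelBackward`, a hypothetical name) whose `vmap` staticmethod receives the unbatched tensors plus `in_dims`, moves the batch dimension of the incoming gradient to the front, and calls the kernel once with the 2D gradient, returning `out_dims=0`. `fake_forward_kernel` and `fake_backward_kernel` are hypothetical NumPy stand-ins for the `_C.point_face_dist_*` kernels (NumPy cannot see vmap's batched tensors, just like a pybind11 kernel), and the squared norm of each point replaces the real point-face distance. This has not been tried against PyTorch3D itself.

```python
import torch
from torch.func import jacrev


# Hypothetical stand-ins for an opaque C++/CUDA extension. They go
# through NumPy on purpose, so vmap cannot trace into them.
def fake_forward_kernel(points):
    p = points.detach().cpu().numpy()
    return torch.from_numpy((p ** 2).sum(-1)).to(points.device)  # (N,)


def fake_backward_kernel(points, grad_d):
    # Accepts grad_d of shape (N,) or (B, N), like the modified C++ kernel.
    p = points.detach().cpu().numpy()
    g = grad_d.detach().cpu().numpy()
    return torch.from_numpy(2.0 * p * g[..., None]).to(points.device)


class KernelBackward(torch.autograd.Function):
    """Wraps the backward kernel so that it gets its own vmap rule."""

    @staticmethod
    def forward(points, grad_d):
        return fake_backward_kernel(points, grad_d)

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing to save; double backward is not supported here

    @staticmethod
    def vmap(info, in_dims, points, grad_d):
        points_bdim, grad_bdim = in_dims
        if points_bdim is not None:
            raise NotImplementedError("batching over points is not supported")
        if grad_bdim is None:
            return KernelBackward.apply(points, grad_d), None
        # Move the batch dim to the front and call the kernel ONCE with the
        # whole (B, N) gradient; the result is (B, N, 3), batched at dim 0.
        grad_d = grad_d.movedim(grad_bdim, 0)
        return KernelBackward.apply(points, grad_d), 0


class PointDist(torch.autograd.Function):
    @staticmethod
    def forward(points):
        return fake_forward_kernel(points)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (points,) = inputs
        ctx.save_for_backward(points)

    @staticmethod
    def backward(ctx, grad_dists):
        (points,) = ctx.saved_tensors
        # jacrev vmaps only this call, so only it needs a batching rule.
        return KernelBackward.apply(points, grad_dists)


points = torch.randn(5, 3)
jac = jacrev(PointDist.apply)(points)  # (N, N, 3)
jac_ref = torch.autograd.functional.jacobian(PointDist.apply, points)
torch.testing.assert_close(jac, jac_ref)
```

`PointDist` deliberately has no `vmap` staticmethod: `jacrev` runs the forward once and only vmaps the backward pass. Using `torch.func.vmap` over the forward as well would additionally require a `vmap` staticmethod on `PointDist` itself.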
Thanks!