How to inspect each grad_fn?

Background

Note that this is not a CUDA-related question!

Hi, I want to implement the following function in CUDA (the CUDA version is already written).

"""
Compute per-vertex normals, where
      xyz:   (batch_size, number of vertices, 3) vertex positions
      faces: (batch_size, number of faces, 3) vertex indices of each face
"""
import torch
from torch.nn import functional as F

# grouping_operation is defined elsewhere (not shown in this post).

def computeVN(xyz, faces):
    vertex_normals = torch.zeros_like(xyz)
    # Gather the three corner positions of every face: (B, M, 3, 3)
    vertex_per_face = grouping_operation(
        xyz.permute(0, 2, 1).contiguous(),
        faces
    ).permute(0, 2, 3, 1)
    v2 = vertex_per_face[:, :, 2, :]
    v1 = vertex_per_face[:, :, 1, :]
    v0 = vertex_per_face[:, :, 0, :]
    vec1 = v2 - v1  # (B, M, 3)
    vec2 = v0 - v1  # (B, M, 3)
    face_normals = torch.cross(vec1, vec2, dim=-1)
    # Accumulate each face normal onto the three vertices of that face.
    for bidx, batch in enumerate(face_normals):
        for fidx, face_normal in enumerate(batch):
            # i0, i1, i2 are vertex indices (renamed so they do not shadow v0, v1, v2)
            i0, i1, i2 = faces[bidx, fidx]
            vertex_normals[bidx, i0] += face_normal
            vertex_normals[bidx, i1] += face_normal
            vertex_normals[bidx, i2] += face_normal
    vertex_normals = F.normalize(vertex_normals, dim=-1)
    return vertex_normals
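
For checking a CUDA kernel against autograd, a loop-free pure-PyTorch reference can be handy. The following is only a sketch: `compute_vn_reference` is a hypothetical name, and it assumes that grouping_operation simply gathers vertex coordinates by index, so plain tensor indexing is used in its place.

```python
import torch
import torch.nn.functional as F


def compute_vn_reference(xyz, faces):
    """Vectorized vertex normals (hypothetical reference, not the CUDA op).

    xyz:   (B, N, 3) float tensor of vertex positions
    faces: (B, M, 3) integer tensor of vertex indices
    """
    B = xyz.shape[0]
    faces = faces.long()
    # Gather the three corners of every face; each result is (B, M, 3).
    batch_idx = torch.arange(B, device=xyz.device)[:, None]
    v0 = xyz[batch_idx, faces[:, :, 0]]
    v1 = xyz[batch_idx, faces[:, :, 1]]
    v2 = xyz[batch_idx, faces[:, :, 2]]
    face_normals = torch.cross(v2 - v1, v0 - v1, dim=-1)  # (B, M, 3)

    # Scatter-add each face normal onto its three vertices.
    vertex_normals = torch.zeros_like(xyz)
    for k in range(3):
        idx = faces[:, :, k, None].expand(-1, -1, 3)  # (B, M, 3)
        vertex_normals = vertex_normals.scatter_add(1, idx, face_normals)
    return F.normalize(vertex_normals, dim=-1)


# A single triangle in the xy-plane: every vertex gets the normal (0, 0, 1).
xyz = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]],
                   requires_grad=True)
faces = torch.tensor([[[0, 1, 2]]])
normals = compute_vn_reference(xyz, faces)
normals.sum().backward()  # autograd gradient to compare a custom backward against
```

Because every step is a built-in op, `xyz.grad` from this version can serve as the expected value when testing a hand-written backward.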

I use the following wrapper so that it can be applied in gradient descent:

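from torch.autograd import Function

# _ext is the compiled CUDA extension module (import not shown in this post).

class ComputeVertexNormals(Function):
    @staticmethod
    def forward(ctx, xyz, faces):
        """
            xyz: (B, N, 3)
            faces: (B, M, 3)
        """
        # save_for_backward is the recommended way to keep input tensors for backward
        ctx.save_for_backward(xyz, faces)
        return _ext.compute_vertex_normals(xyz, faces)

    @staticmethod
    def backward(ctx, grad_out):
        """
            grad_out: (B, N, 3)
        """
        xyz, faces = ctx.saved_tensors
        # faces holds integer indices, so it receives no gradient (None)
        return _ext.compute_vertex_normals_grad(grad_out, xyz, faces), None

compute_vertex_normals = ComputeVertexNormals.apply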

However, my compute_vertex_normals_grad is not working as expected. Now I want to debug it by inspecting each gradient.
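
One way to test a hand-written backward without reading any source is torch.autograd.gradcheck, which compares it against finite differences. The sketch below uses a hypothetical stand-in, `NormalizeFn`, that re-implements only the final normalization step with a manual backward; the real CUDA op is not available here, and gradcheck works best in double precision, which the extension may or may not support.

```python
import torch
from torch.autograd import Function, gradcheck


class NormalizeFn(Function):
    """Hypothetical stand-in for a custom op with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        n = x.norm(dim=-1, keepdim=True)
        y = x / n
        ctx.save_for_backward(y, n)
        return y

    @staticmethod
    def backward(ctx, grad_out):
        y, n = ctx.saved_tensors
        # For y = x / |x|: dL/dx = (g - y * (y . g)) / |x|
        return (grad_out - y * (y * grad_out).sum(-1, keepdim=True)) / n


torch.manual_seed(0)
x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
# gradcheck returns True on success and raises an error on a mismatch.
ok = gradcheck(NormalizeFn.apply, (x,), eps=1e-6, atol=1e-5)
print(ok)
```

If the check fails, the error message reports the numerical and analytical Jacobians, which usually narrows down which term of the backward is wrong.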

Question

  • Is it possible for me to check the detailed implementation of each grad_fn? For example, the grad_fn for face_normals is a LinalgCrossBackward0 object, but I don't know what is inside this function.
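
The Python side does not expose the body of a grad_fn, but the graph structure and the gradient that reaches each intermediate tensor can be inspected. A minimal sketch, using small random tensors in place of the real mesh data (`walk` and the tensor names are illustrative only):

```python
import torch
import torch.nn.functional as F

names = []


def walk(fn, depth=0):
    """Recursively print the backward graph starting from a grad_fn node."""
    if fn is None:
        return
    names.append(fn.name())
    print("  " * depth + fn.name())
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1)


a = torch.randn(5, 3, requires_grad=True)
b = torch.randn(5, 3, requires_grad=True)
face_normals = torch.cross(a, b, dim=-1)
face_normals.retain_grad()  # keep .grad for this non-leaf tensor

loss = F.normalize(face_normals, dim=-1).sum()
walk(loss.grad_fn)          # prints one node name per line, indented by depth

loss.backward()
print(face_normals.grad.shape)  # gradient arriving at face_normals
```

Calling retain_grad() on each intermediate tensor makes it possible to compare the gradient at every stage with the hand-derived value, so the first stage that disagrees can be isolated.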

Why I Want to Inspect Its Implementation:

I derived the gradient by hand and got the following equation (taking v2 as an example):

But the result of this equation is completely different from the one produced by loss.backward().

One approach would be to look up the function in derivatives.yaml (tools/autograd/derivatives.yaml in the PyTorch repository; I just searched for linalg_cross) and then search the codebase for the backward implementation that the entry refers to.
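
For the cross product specifically, the backward can also be checked numerically. For c = a × b with upstream gradient g, the standard vector-Jacobian products are dL/da = b × g and dL/db = g × a; as far as I can tell, the linalg_cross entry has the same form, but please verify that against the file itself. A small sketch with random tensors as placeholders:

```python
import torch

torch.manual_seed(0)
a = torch.randn(6, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(6, 3, dtype=torch.double, requires_grad=True)

c = torch.linalg.cross(a, b, dim=-1)
g = torch.randn_like(c)  # arbitrary upstream gradient dL/dc

# Gradients computed by autograd.
grad_a, grad_b = torch.autograd.grad(c, (a, b), grad_outputs=g)

# Hand-derived vector-Jacobian products for c = a x b.
hand_a = torch.linalg.cross(b.detach(), g, dim=-1)  # dL/da = b x g
hand_b = torch.linalg.cross(g, a.detach(), dim=-1)  # dL/db = g x a

match_a = torch.allclose(grad_a, hand_a)
match_b = torch.allclose(grad_b, hand_b)
print(match_a, match_b)
```

Note the argument order: swapping the operands in either product flips the sign, which is an easy mistake to make in a hand derivation.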

Thanks for your help, but I still cannot find its grad_fn implementation. Are there any documents that can help me understand the PyTorch source code?