How to inspect each grad_fn?

Background

Note that this is not a CUDA-related question!

Hi, I want to implement the following function on CUDA (which I have already done).

"""
This is a function of calculating vertex normals where
      xyz: (batch_size, number of vertices, 3)
      faces: (batch_size, number of faces, 3)
"""
def computeVN(xyz, faces):
    vertex_normals = torch.zeros_like(xyz)
    vertex_per_face = grouping_operation(
        xyz.permute(0, 2, 1).contiguous(),
        faces
    ).permute(0, 2, 3, 1)
    v2 = vertex_per_face[:, :, 2, :]
    v1 = vertex_per_face[:, :, 1, :]
    v0 = vertex_per_face[:, :, 0, :]
    vec1 = v2 - v1 # (B, M, 3)
    vec2 = v0 - v1
    face_normals = torch.cross(vec1, vec2, -1)
    for bidx, batch in enumerate(face_normals):
        for fidx, face_normal in enumerate(batch):
            v0, v1, v2 = faces[bidx, fidx]
            vertex_normals[bidx, v0] += face_normal
            vertex_normals[bidx, v1] += face_normal
            vertex_normals[bidx, v2] += face_normal
    from torch.nn import functional as F
    vertex_normals = F.normalize(vertex_normals, dim=-1)
    return vertex_normals
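
(Not part of my original code, but for debugging it can help to have a reference that stays entirely inside autograd, with no grouping_operation and no Python loops. Here is a sketch that assumes faces holds integer vertex indices into xyz; comparing its gradients with the CUDA kernel's output isolates whether the bug is in compute_vertex_normals_grad.)

import torch
from torch.nn import functional as F

def computeVN_reference(xyz, faces):
    """Pure-PyTorch vertex normals; every op is differentiable by autograd."""
    B, N, _ = xyz.shape
    _, M, _ = faces.shape
    # Gather the three corner vertices of each face: (B, M, 3, 3)
    gather_idx = faces.long().reshape(B, M * 3, 1).expand(-1, -1, 3)
    vertex_per_face = torch.gather(xyz, 1, gather_idx).reshape(B, M, 3, 3)
    v0, v1, v2 = vertex_per_face.unbind(dim=2)          # each (B, M, 3)
    face_normals = torch.cross(v2 - v1, v0 - v1, dim=-1)
    # Scatter-add each face normal onto its three vertices (replaces the loops)
    vertex_normals = torch.zeros_like(xyz)
    for corner_idx in faces.long().unbind(dim=-1):      # (B, M) per corner
        scatter_idx = corner_idx.unsqueeze(-1).expand(-1, -1, 3)
        vertex_normals = vertex_normals.scatter_add(1, scatter_idx, face_normals)
    return F.normalize(vertex_normals, dim=-1)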

I wrap it in an autograd Function so that it can be used in gradient descent:

from torch.autograd import Function

class ComputeVertexNormals(Function):
    @staticmethod
    def forward(ctx, xyz, faces):
        """
            xyz: (B, N, 3)
            faces: (B, M, 3)
        """
        ctx.for_backwards = (xyz, faces)
        return _ext.compute_vertex_normals(xyz, faces)

    @staticmethod
    def backward(ctx, grad_out):
        """
            grad_out: (B, N, 3)
        """
        xyz, faces = ctx.for_backwards
        # faces is an index tensor, so it receives no gradient (hence None)
        return _ext.compute_vertex_normals_grad(grad_out, xyz, faces), None

compute_vertex_normals = ComputeVertexNormals.apply
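
(As an aside, not something my code currently does: the documented way to stash tensors for backward is ctx.save_for_backward, which adds sanity checks that a plain attribute skips. A sketch of the same two methods inside ComputeVertexNormals:)

@staticmethod
def forward(ctx, xyz, faces):
    ctx.save_for_backward(xyz, faces)   # documented API instead of ctx.for_backwards
    return _ext.compute_vertex_normals(xyz, faces)

@staticmethod
def backward(ctx, grad_out):
    xyz, faces = ctx.saved_tensors
    return _ext.compute_vertex_normals_grad(grad_out, xyz, faces), None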

However, my compute_vertex_normals_grad is not returning the gradients I expect, so I want to debug it by inspecting each intermediate gradient.
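
(One way to do that, as a sketch that assumes the CUDA extension accepts double tensors and int32 face indices on the same device, is torch.autograd.gradcheck, which compares the analytical backward against finite differences on a tiny input:)

import torch
from torch.autograd import gradcheck

# Tiny made-up mesh; gradcheck wants double precision and small inputs.
xyz = torch.rand(1, 6, 3, dtype=torch.double, device="cuda", requires_grad=True)
faces = torch.tensor([[[0, 1, 2], [2, 3, 4], [4, 5, 0]]],
                     dtype=torch.int32, device="cuda")

# Raises (with a detailed message) if ComputeVertexNormals.backward disagrees
# with the numerical Jacobian beyond tolerance.
gradcheck(compute_vertex_normals, (xyz, faces), eps=1e-6, atol=1e-4)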

Question

  • Is it possible for me to check the detailed implementation of each grad_fn? For example, the grad_fn of face_normals is a LinalgCrossBackward0 object, but I don't know what is inside this function (see the snippet below for what I can already see from Python).
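
(A minimal sketch of how the backward graph can be walked and intermediate gradients printed; the exact printed names depend on the PyTorch version.)

# Walk the autograd graph starting from face_normals
face_normals = torch.cross(vec1, vec2, dim=-1)

node = face_normals.grad_fn
print(node)                   # e.g. <LinalgCrossBackward0 object at 0x...>
print(node.next_functions)    # the parent nodes feeding this backward op

# Print the gradient flowing through an intermediate tensor during backward
vec1.register_hook(lambda g: print("grad w.r.t. vec1:", g.shape))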

Why I Want to Inspect Its Implementation:

I have calculated the derivative by hand and got the following equation (here we take v2 as an example):

[image: hand-derived gradient equation for v2]

But the result of this equation is completely different from the one computed by loss.backward().

One approach would be to check derivatives.yaml for the corresponding function name (I just searched for linalg_cross) and then look up the backward implementation defined there.
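
(For reference, the analytical backward of a cross product can also be checked against autograd directly. This is only a sketch with made-up tensors, not the code generated from derivatives.yaml: for c = a x b, the chain rule gives dL/da = b x g and dL/db = g x a, where g is the incoming gradient.)

import torch

a = torch.randn(5, 3, requires_grad=True)
b = torch.randn(5, 3, requires_grad=True)
c = torch.cross(a, b, dim=-1)

g = torch.randn_like(c)   # plays the role of grad_out
c.backward(g)

# The analytical formulas should match what autograd stored in a.grad and b.grad
assert torch.allclose(a.grad, torch.cross(b.detach(), g, dim=-1), atol=1e-6)
assert torch.allclose(b.grad, torch.cross(g, a.detach(), dim=-1), atol=1e-6)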

Thanks for your help, but I still cannot find the grad_fn implementation. Are there any documents that can help me understand the PyTorch source code?