Background
Note that this is not a CUDA-related question!
Hi, I have implemented the following function in CUDA.
```python
import torch
from torch.nn import functional as F

# grouping_operation is an external CUDA op (not shown here). As used below, it
# gathers features of shape (B, C, N) at the indices in `faces` -> (B, C, M, 3).


def computeVN(xyz, faces):
    """Calculate vertex normals.

    xyz:   (batch_size, number of vertices, 3)
    faces: (batch_size, number of faces, 3)
    """
    vertex_normals = torch.zeros_like(xyz)
    # (B, 3, M, 3) -> (B, M, 3 corners, 3 coords)
    vertex_per_face = grouping_operation(
        xyz.permute(0, 2, 1).contiguous(),
        faces
    ).permute(0, 2, 3, 1)
    v2 = vertex_per_face[:, :, 2, :]
    v1 = vertex_per_face[:, :, 1, :]
    v0 = vertex_per_face[:, :, 0, :]
    vec1 = v2 - v1  # (B, M, 3)
    vec2 = v0 - v1
    # Unnormalized (area-weighted) face normals, (B, M, 3)
    face_normals = torch.cross(vec1, vec2, dim=-1)
    # Accumulate every face normal onto its three vertices.
    for bidx, batch in enumerate(face_normals):
        for fidx, face_normal in enumerate(batch):
            i0, i1, i2 = faces[bidx, fidx]
            vertex_normals[bidx, i0] += face_normal
            vertex_normals[bidx, i1] += face_normal
            vertex_normals[bidx, i2] += face_normal
    vertex_normals = F.normalize(vertex_normals, dim=-1)
    return vertex_normals
```
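For comparison, the same computation can be written in pure PyTorch without `grouping_operation` or the Python loops, so that autograd can differentiate it end to end and serve as a reference for the CUDA kernels. This is a sketch assuming `faces` holds `int64` indices into the vertex dimension; `compute_vn_reference` is a hypothetical name. It gathers the face corners with advanced indexing and scatters the face normals back onto the vertices with `Tensor.index_add`.

```python
import torch
import torch.nn.functional as F


def compute_vn_reference(xyz: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
    """Vertex normals without grouping_operation or Python loops (hypothetical helper).

    xyz:   (B, N, 3) float tensor of vertex positions
    faces: (B, M, 3) int64 tensor of vertex indices per face
    """
    B, N, _ = xyz.shape
    bi = torch.arange(B, device=xyz.device)[:, None]  # (B, 1), broadcasts against (B, M)
    # Gather the three corners of every face, each (B, M, 3).
    v0, v1, v2 = (xyz[bi, faces[..., k]] for k in range(3))
    # Same unnormalized face normal as computeVN: (v2 - v1) x (v0 - v1).
    face_normals = torch.linalg.cross(v2 - v1, v0 - v1, dim=-1)  # (B, M, 3)

    # Flatten the batch so a single index_add scatters all meshes at once.
    offsets = N * torch.arange(B, device=xyz.device)[:, None, None]  # (B, 1, 1)
    flat_idx = (faces + offsets).reshape(-1)  # (B*M*3,), ordered (b, m, corner)
    # Repeat each face normal once per corner, in the same (b, m, corner) order.
    src = face_normals.unsqueeze(2).expand(-1, -1, 3, -1).reshape(-1, 3)
    vertex_normals = xyz.new_zeros(B * N, 3).index_add(0, flat_idx, src)
    return F.normalize(vertex_normals.view(B, N, 3), dim=-1)
```

Because every operation here is a standard differentiable op, `compute_vn_reference(xyz, faces)` can be backpropagated through directly and its gradient used as ground truth.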
I use the following `torch.autograd.Function` wrapper so that it can be used in gradient descent:
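```python
from torch.autograd import Function

# _ext is the compiled CUDA extension (not shown here).


class ComputeVertexNormals(Function):
    @staticmethod
    def forward(ctx, xyz, faces):
        """
        xyz: (B, N, 3)
        faces: (B, M, 3)
        """
        # save_for_backward is the recommended way to keep input tensors for backward.
        ctx.save_for_backward(xyz, faces)
        return _ext.compute_vertex_normals(xyz, faces)

    @staticmethod
    def backward(ctx, grad_out):
        """
        grad_out: (B, N, 3)
        """
        xyz, faces = ctx.saved_tensors
        # faces are integer indices, so they receive no gradient.
        return _ext.compute_vertex_normals_grad(grad_out, xyz, faces), None


compute_vertex_normals = ComputeVertexNormals.apply
```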
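One way to test a custom backward is `torch.autograd.gradcheck`, which compares the analytical gradient returned by `backward` with finite differences in double precision. The sketch below uses a hypothetical `ComputeVertexNormalsRef` whose forward and backward are pure-PyTorch stand-ins (`vn_forward`, `vn_grad`, also hypothetical) for `_ext.compute_vertex_normals` and `_ext.compute_vertex_normals_grad`. Swapping the `_ext` calls back in assumes the CUDA kernels accept `float64` tensors.

```python
import torch
import torch.nn.functional as F
from torch.autograd import Function, gradcheck


def vn_forward(xyz, faces):
    """Pure-PyTorch vertex normals, standing in for _ext.compute_vertex_normals."""
    B, N, _ = xyz.shape
    bi = torch.arange(B, device=xyz.device)[:, None]
    v0, v1, v2 = (xyz[bi, faces[..., k]] for k in range(3))
    fn = torch.linalg.cross(v2 - v1, v0 - v1, dim=-1)  # (B, M, 3)
    idx = (faces + N * torch.arange(B, device=xyz.device)[:, None, None]).reshape(-1)
    src = fn.unsqueeze(2).expand(-1, -1, 3, -1).reshape(-1, 3)
    vn = xyz.new_zeros(B * N, 3).index_add(0, idx, src)
    return F.normalize(vn.view(B, N, 3), dim=-1)


def vn_grad(grad_out, xyz, faces):
    """Stand-in for _ext.compute_vertex_normals_grad: autograd's vector-Jacobian product."""
    with torch.enable_grad():  # grad mode is off inside backward by default
        x = xyz.detach().requires_grad_(True)
        (g,) = torch.autograd.grad(vn_forward(x, faces), x, grad_out)
    return g


class ComputeVertexNormalsRef(Function):
    @staticmethod
    def forward(ctx, xyz, faces):
        ctx.save_for_backward(xyz, faces)
        return vn_forward(xyz, faces)  # replace with _ext.compute_vertex_normals

    @staticmethod
    def backward(ctx, grad_out):
        xyz, faces = ctx.saved_tensors
        return vn_grad(grad_out, xyz, faces), None  # replace with _ext.compute_vertex_normals_grad


# Usage: a tetrahedron with randomly placed vertices, in float64 as gradcheck requires.
torch.manual_seed(0)
faces = torch.tensor([[[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]]])
xyz = torch.randn(1, 4, 3, dtype=torch.float64, requires_grad=True)
print(gradcheck(ComputeVertexNormalsRef.apply, (xyz, faces), eps=1e-6, atol=1e-4))
```

If the kernels only support `float32`, a looser alternative is to compare `_ext.compute_vertex_normals_grad(grad_out, xyz, faces)` with `vn_grad(grad_out, xyz, faces)` using `torch.allclose` and a relaxed tolerance.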
However, my `compute_vertex_normals_grad` is not working as expected, so I want to debug it by inspecting each intermediate gradient.
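To inspect each intermediate gradient, one option is to rebuild the forward pass in plain PyTorch and call `retain_grad()` on the non-leaf tensors. After `backward`, every stage (the normalized output, the per-vertex sum before `F.normalize`, `face_normals`, and `xyz`) has its own `.grad`, which can be compared with the corresponding step of the CUDA kernel. The mesh (a tetrahedron with random vertex positions) and the names `unnormalized` and `grad_out` are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
xyz = torch.randn(1, 4, 3, dtype=torch.float64, requires_grad=True)
faces = torch.tensor([[[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]]])
B, N, _ = xyz.shape

# Forward pass, split into named stages.
bi = torch.arange(B)[:, None]
v0, v1, v2 = (xyz[bi, faces[..., k]] for k in range(3))
face_normals = torch.linalg.cross(v2 - v1, v0 - v1, dim=-1)  # (B, M, 3)
idx = (faces + N * torch.arange(B)[:, None, None]).reshape(-1)
src = face_normals.unsqueeze(2).expand(-1, -1, 3, -1).reshape(-1, 3)
unnormalized = xyz.new_zeros(B * N, 3).index_add(0, idx, src).view(B, N, 3)
vertex_normals = F.normalize(unnormalized, dim=-1)

# Keep .grad on the non-leaf intermediates so they can be compared stage by stage.
for t in (face_normals, unnormalized, vertex_normals):
    t.retain_grad()

grad_out = torch.randn_like(vertex_normals)  # plays the role of grad_out in backward()
vertex_normals.backward(grad_out)

print(unnormalized.grad)  # dL/d(sum of face normals), i.e. after undoing F.normalize
print(face_normals.grad)  # dL/d(face normal), shape (B, M, 3)
print(xyz.grad)           # final dL/dxyz, what compute_vertex_normals_grad should return
```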
Question
- Is it possible to check the detailed implementation of each `grad_fn`? For example, the `grad_fn` of `face_normals` is a `LinalgCrossBackward0` object, but I don't know what is inside this function.
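From Python, a `grad_fn` node mainly exposes metadata such as `name()` and `next_functions`; the backward formula itself is C++ code generated when PyTorch is built. In the PyTorch source tree, most of these formulas are declared in `tools/autograd/derivatives.yaml`, where the `linalg_cross` entry should be found. For the cross product the formula can also be derived and checked directly: with `c = a × b` and upstream gradient `g`, the scalar triple product gives `g·(a × b) = a·(b × g) = b·(g × a)`, so `∂L/∂a = b × g` and `∂L/∂b = g × a`. The sketch below checks this numerically; the tensor shapes are arbitrary.

```python
import torch

torch.manual_seed(0)
a = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)  # plays vec1
b = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)  # plays vec2
c = torch.linalg.cross(a, b, dim=-1)

# The node only exposes metadata, not the formula.
print(c.grad_fn.name())          # name of the backward node
print(c.grad_fn.next_functions)  # the nodes it passes gradients on to

g = torch.randn_like(c)
ga, gb = torch.autograd.grad(c, (a, b), g)

# Closed form from the scalar triple product.
print(torch.allclose(ga, torch.linalg.cross(b, g, dim=-1)))
print(torch.allclose(gb, torch.linalg.cross(g, a, dim=-1)))
```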
Why I Want to Inspect Its Implementation
I derived the gradient by hand and got the following equation (taking v2 as an example):
However, the result of this equation is totally different from the gradient computed by `loss.backward()`.
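For reference, here is one way to write out the full chain rule; it is a sketch, not a reconstruction of the equation above. Per face f with corners a = v0, b = v1, c = v2, the forward pass is n_f = (c - b) × (a - b), s_i = Σ n_f over the faces containing vertex i, and u_i = s_i / ‖s_i‖. Given g_i = ∂L/∂u_i:

1. ĝ_i = ∂L/∂s_i = (g_i - u_i (u_i · g_i)) / ‖s_i‖ (the `F.normalize` step),
2. h_f = ∂L/∂n_f = ĝ_a + ĝ_b + ĝ_c (each face collects the gradients of all three of its corners),
3. ∂L/∂c += (a - b) × h_f, ∂L/∂a += h_f × (c - b), ∂L/∂b += h_f × (a - c).

So for v2 the per-face term is (v0 - v1) × h_f, and h_f already includes steps 1 and 2. The function `vertex_normals_grad_reference` below is hypothetical; it mirrors the signature of `_ext.compute_vertex_normals_grad` and assumes the same `eps` as the default of `F.normalize`.

```python
import torch


def vertex_normals_grad_reference(grad_out, xyz, faces, eps=1e-12):
    """Hand-written backward of normalize(scatter_add(cross(v2 - v1, v0 - v1))).

    grad_out: (B, N, 3), xyz: (B, N, 3), faces: (B, M, 3) int64.
    Returns dL/dxyz with shape (B, N, 3).
    """
    B, N, _ = xyz.shape
    bi = torch.arange(B, device=xyz.device)[:, None]
    a, b, c = (xyz[bi, faces[..., k]] for k in range(3))  # v0, v1, v2: (B, M, 3)
    n = torch.linalg.cross(c - b, a - b, dim=-1)           # face normals (B, M, 3)

    # Recompute the forward: s = per-vertex sum of face normals, u = s / max(||s||, eps).
    idx = (faces + N * torch.arange(B, device=xyz.device)[:, None, None]).reshape(-1)
    corner_n = n.unsqueeze(2).expand(-1, -1, 3, -1).reshape(-1, 3)
    s = xyz.new_zeros(B * N, 3).index_add(0, idx, corner_n).view(B, N, 3)
    raw = s.norm(dim=-1, keepdim=True)
    denom = raw.clamp_min(eps)
    u = s / denom

    # Step 1: backward through F.normalize (plain division when ||s|| is clamped).
    proj = grad_out - u * (u * grad_out).sum(-1, keepdim=True)
    gs = torch.where(raw >= eps, proj, grad_out) / denom

    # Step 2: every face sums the gradients of its three corner vertices.
    h = gs.reshape(B * N, 3)[idx].view(B, -1, 3, 3).sum(dim=2)  # (B, M, 3)

    # Step 3: cross-product backward for each corner of each face.
    ga = torch.linalg.cross(h, c - b, dim=-1)  # dL/dv0
    gb = torch.linalg.cross(h, a - c, dim=-1)  # dL/dv1
    gc = torch.linalg.cross(a - b, h, dim=-1)  # dL/dv2

    # Step 4: scatter the per-corner gradients back onto the vertices.
    per_corner = torch.stack((ga, gb, gc), dim=2).reshape(-1, 3)  # same order as idx
    grad_xyz = xyz.new_zeros(B * N, 3).index_add(0, idx, per_corner)
    return grad_xyz.view(B, N, 3)
```

Comparing each step of this function with the matching stage of the CUDA kernel should narrow down where the two diverge.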