I’m running into an error during backprop through a function that computes skeleton joint positions in 3D space from rotation matrices. The function walks the skeleton hierarchy and computes the global position of one joint at a time.
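For reference, this is the recursion the function is meant to implement. It is written in the row-vector convention the code appears to use, where the offset is a row vector multiplied on the left of the parent's global rotation. The notation is introduced here for illustration: $p(i)$ is the parent of joint $i$, $R_i$ its local rotation, $G_i$ its global rotation, $o_i$ its offset and $x_i$ its global position. The root $r$ has $G_r = R_r$ and $x_r = 0$.

```latex
G_i = R_i \, G_{p(i)}, \qquad x_i = x_{p(i)} + o_i \, G_{p(i)}
```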
Implementation:
```python
from typing import List

import torch


def torch_mat_to_pos(mat_rotations: torch.Tensor, parents: List[int], offsets: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to joint positions.
    Args:
        mat_rotations: Rotation matrices as tensor of shape (frames, joints, 3, 3).
        parents: Parent index of each joint as a list of length joints (-1 for the root).
        offsets: Offsets of each joint as tensor of shape (joints, 3).
    Returns:
        Joint positions as tensor of shape (frames, joints, 3).
    """
    global_rot = torch.clone(mat_rotations)
    global_pos = torch.zeros((mat_rotations.shape[0], mat_rotations.shape[1], 3),
                             dtype=mat_rotations.dtype, device=mat_rotations.device)
    for i, parent in enumerate(parents):
        if parent == -1:
            continue
        # multiply this joint's rotmat by the global rotmat of its parent
        global_rot[:, i] = mat_rotations[:, i] @ global_rot[:, parent]
        k = offsets[i].repeat((mat_rotations.shape[0], 1, 1))
        # multiply the offsets by the parent's rotmat
        q = k @ global_rot[:, parent]
        # add the offsets to the parent's position
        global_pos[:, i] = global_pos[:, parent] + torch.squeeze(q, 1)
    return global_pos
```
I’ve enabled anomaly detection, and it seems like the issue is in this line:
```python
global_rot[:, i] = mat_rotations[:, i] @ global_rot[:, parent]
```
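A likely explanation (hedged, since the full error message isn't shown): `mat_rotations[:, i] @ global_rot[:, parent]` saves the view `global_rot[:, parent]` for the backward pass, and the assignment into `global_rot[:, i]` then modifies `global_rot` in place. All views of a tensor share its version counter, so autograd sees that a tensor it saved has changed. One way around this is to never write into a tensor that is part of the graph: collect each joint's result in Python lists and `torch.stack` them at the end. This sketch assumes, as the original loop already does, that every parent index is smaller than its child's index, so a parent is always computed before its children. It adds the rotated offset to the parent's position, as the `# add the offsets to the parent's position` comment describes.

```python
from typing import List

import torch


def torch_mat_to_pos(mat_rotations: torch.Tensor, parents: List[int],
                     offsets: torch.Tensor) -> torch.Tensor:
    """Out-of-place variant: results are appended to lists, never written into a tensor."""
    frames = mat_rotations.shape[0]
    global_rots: List[torch.Tensor] = []
    global_pos: List[torch.Tensor] = []
    for i, parent in enumerate(parents):
        if parent == -1:
            # Root joint: its global rotation is its local one; it sits at the origin.
            global_rots.append(mat_rotations[:, i])
            global_pos.append(mat_rotations.new_zeros((frames, 3)))
            continue
        # Assumes parent < i, so the parent's entry was appended in an earlier iteration.
        parent_rot = global_rots[parent]  # (frames, 3, 3)
        global_rots.append(mat_rotations[:, i] @ parent_rot)
        # (3,) @ (frames, 3, 3) broadcasts to (frames, 3): offset rotated by the parent.
        global_pos.append(global_pos[parent] + offsets[i] @ parent_rot)
    # Stack along the joint dimension -> (frames, joints, 3)
    return torch.stack(global_pos, dim=1)
```

Calling `.sum().backward()` on the returned tensor then populates `mat_rotations.grad` (and `offsets.grad`, if the offsets require grad) without any in-place writes involved.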
I also tried creating `global_rot` with `torch.zeros_like` and ended up with the same result.
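Initialising `global_rot` differently is not expected to help, because the problem is not what the tensor initially contains but that slices of it are fed into a matmul and the tensor is then written to in place. If you'd rather keep the preallocated tensors, a smaller change is to `.clone()` the parent's slice before using it: the matmul then saves the clone, which is never modified, instead of a view of `global_rot`. This is a sketch under the same parent-before-child assumption as the original loop; the name `torch_mat_to_pos_cloned` is hypothetical, and it adds the rotated offset to the parent's position.

```python
from typing import List

import torch


def torch_mat_to_pos_cloned(mat_rotations: torch.Tensor, parents: List[int],
                            offsets: torch.Tensor) -> torch.Tensor:
    frames, joints = mat_rotations.shape[:2]
    global_rot = torch.clone(mat_rotations)
    global_pos = mat_rotations.new_zeros((frames, joints, 3))
    for i, parent in enumerate(parents):
        if parent == -1:
            continue
        # Clone the parent's slices so that nothing autograd saves for backward
        # is a view of global_rot or global_pos, which are written to below.
        parent_rot = global_rot[:, parent].clone()
        parent_pos = global_pos[:, parent].clone()
        global_rot[:, i] = mat_rotations[:, i] @ parent_rot
        # (3,) @ (frames, 3, 3) broadcasts to (frames, 3)
        global_pos[:, i] = parent_pos + offsets[i] @ parent_rot
    return global_pos
```

The extra clones cost a little memory per joint; the list-and-stack approach avoids them but changes more of the original code.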
Also note that I’m already getting this error at index 1, right after this line (when testing backprop on `global_rot[:, i]`).
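The failure at index 1 can be reproduced in isolation. The sketch below runs only the first non-root step of the loop (joint 1 with parent 0), with random tensors standing in for real rotations; `backward_fails` is a hypothetical helper name. The matmul keeps the view `global_rot[:, 0]` for its backward pass, and writing into `global_rot[:, 1]` bumps the version counter shared by all views of `global_rot`, so backward should refuse to use the saved view. Cloning that slice first is expected to make the error go away.

```python
import torch


def backward_fails(clone_parent: bool) -> bool:
    """Run only the i == 1 step (parent 0) and report whether backward raises."""
    torch.manual_seed(0)
    mat_rotations = torch.randn(2, 3, 3, 3, requires_grad=True)  # (frames, joints, 3, 3)
    global_rot = torch.clone(mat_rotations)
    parent_rot = global_rot[:, 0]  # a view that shares global_rot's version counter
    if clone_parent:
        parent_rot = parent_rot.clone()  # independent copy, never modified afterwards
    # matmul saves parent_rot for backward; the assignment then modifies global_rot in place
    global_rot[:, 1] = mat_rotations[:, 1] @ parent_rot
    try:
        global_rot[:, 1].sum().backward()
    except RuntimeError as err:
        print("backward failed:", err)
        return True
    return False


print(backward_fails(clone_parent=False))
print(backward_fails(clone_parent=True))
```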
Any suggestions / ideas on how to solve this? Thanks in advance!