I’m running into an error during backprop through a function that computes skeleton joint positions in 3D space from rotation matrices. The function loops over the joints and computes the global position of one joint at a time.
import torch
from typing import List

def torch_mat_to_pos(mat_rotations: torch.Tensor, parents: List[int], offsets: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to joint positions.

    Args:
        mat_rotations: Rotation matrices as tensor of shape (frames, joints, 3, 3).
        parents: Parents of each joint as tensor of shape (joints,).
        offsets: Offsets of each joint as tensor of shape (joints, 3).

    Returns:
        Joint positions as tensor of shape (frames, joints, 3).
    """
    global_rot = torch.clone(mat_rotations)
    global_pos = torch.zeros((mat_rotations.shape[0], mat_rotations.shape[1], 3),
                             dtype=mat_rotations.dtype, device=mat_rotations.device)
    for i, parent in enumerate(parents):
        if parent == -1:
            continue
        # multiply this joint's rotmat to the rotmat of its parent
        global_rot[:, i] = mat_rotations[:, i] @ global_rot[:, parent]
        k = offsets[i].repeat((mat_rotations.shape[0], 1, 1))
        # multiply the offsets by the parent's rotmat
        q = k @ global_rot[:, parent]
        # add the offsets to the parent's position
        global_pos[:, i] = global_pos[:, i] + torch.squeeze(q, 1)
    return global_pos
I’ve enabled anomaly detection, and it seems like the issue is in this line:
global_rot[:, i] = mat_rotations[:, i] @ global_rot[:, parent]
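The post doesn't quote the error message. Assuming it is autograd's "modified by an inplace operation" error, the following minimal sketch appears to reproduce it with two joints, where joint 1's parent is joint 0. The matmul saves the view global_rot[:, 0] for backward, and the slice assignment into global_rot then bumps the version counter that this view shares with global_rot. The function name reproduce_inplace_error is hypothetical.

```python
import torch


def reproduce_inplace_error() -> str:
    # Stand-in for mat_rotations with shape (frames=2, joints=2, 3, 3).
    mats = torch.randn(2, 2, 3, 3, requires_grad=True)
    global_rot = torch.clone(mats)
    # matmul saves global_rot[:, 0] (a view of global_rot) for the backward pass.
    # The slice assignment is an in-place copy into global_rot, which increments
    # the version counter shared by global_rot and all of its views.
    global_rot[:, 1] = mats[:, 1] @ global_rot[:, 0]
    try:
        global_rot.sum().backward()
    except RuntimeError as err:
        # Autograd detects that the saved view changed after it was saved.
        return str(err)
    return "no error"


print(reproduce_inplace_error())
```

With torch.autograd.set_detect_anomaly(True) enabled, PyTorch additionally prints the traceback of the forward operation whose backward failed, which would point at the matmul line.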
I tried creating global_rot with torch.zeros_like instead and got the same error.
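Any preallocated buffer that is read slice by slice and also written slice by slice would hit the same check, whichever constructor created it. One common workaround, sketched here under assumptions: collect the per-joint rotations and positions in Python lists of new tensors and torch.stack them at the end, so no tensor saved by autograd is ever written in place. The name torch_mat_to_pos_stacked is hypothetical. The sketch assumes every parent index is smaller than its child's index, and it adds the parent's position as the comment "add the offsets to the parent's position" describes (the original code adds to global_pos[:, i] instead).

```python
from typing import List

import torch


def torch_mat_to_pos_stacked(mat_rotations: torch.Tensor, parents: List[int],
                             offsets: torch.Tensor) -> torch.Tensor:
    """Out-of-place variant: returns positions of shape (frames, joints, 3)."""
    frames = mat_rotations.shape[0]
    # List index j holds joint j's tensors, since joints are appended in order.
    global_rot: List[torch.Tensor] = []
    global_pos: List[torch.Tensor] = []
    for i, parent in enumerate(parents):
        if parent == -1:
            # Root joint: keep its own rotation and place it at the origin.
            global_rot.append(mat_rotations[:, i])
            global_pos.append(torch.zeros(frames, 3, dtype=mat_rotations.dtype,
                                          device=mat_rotations.device))
            continue
        # Assumes parent < i, so the parent's entries already exist.
        parent_rot = global_rot[parent]
        # New tensor each iteration; nothing previously saved is overwritten.
        global_rot.append(mat_rotations[:, i] @ parent_rot)
        # (3,) @ (frames, 3, 3) -> (frames, 3): row-vector offset times parent rotation.
        q = offsets[i] @ parent_rot
        global_pos.append(global_pos[parent] + q)
    return torch.stack(global_pos, dim=1)
```

For example, with identity rotations, parents [-1, 0, 1] and offsets [[0, 0, 0], [1, 0, 0], [0, 2, 0]], the last joint ends up at the sum of the offsets along its chain, [1, 2, 0], and calling .sum().backward() on the result runs without the in-place error.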
Also note that I already get the error at index 1, right after this line (when testing backprop on global_rot[:, i]).
Any suggestions or ideas on how to solve this? Thanks in advance!