Do gradients persist through new Tensors (non-neural-network use of autograd)?

I am trying to use PyTorch for a non-neural-network application: taking local pose estimates to global pose estimates. I am using a custom loss function that computes the difference between global 3D xyz coordinates, projected into the image plane, and 2D uv image pixel coordinates. My question is: do gradients persist when you call retain_grad on a tensor but then pass it into a new tensor initialization?

I'm using a kinematic hand model for pose estimation, so I need the gradients with respect to the yaw/pitch/roll of the wrist rotation matrix, the wrist translation, and the flexion/abduction angles of the fingers after running forward kinematics on the hand model. I randomly initialize the R, p, and theta values and use ground-truth image coordinates to perform gradient descent.
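The loop I have in mind is roughly the following (simplified sketch; R_from_angs, hand_fk, loss_2d and the tensors are the ones set up below, and the learning rate is just a placeholder):

# sketch of the intended descent loop; all names are from the code below
params = [yaw, pitch, roll, p, theta]         # created with requires_grad=True
optimizer = torch.optim.SGD(params, lr=1e-3)  # plain gradient descent
for step in range(1000):
    optimizer.zero_grad()
    R_wrist = R_from_angs(yaw, pitch, roll)
    xyz_fk = hand_fk(R_wrist, p, torch.from_numpy(avg_bone_length).float(),
                     parent_unit_vec, theta)
    loss = loss_2d(xyz_fk, uv_pred, torch.as_tensor(conf_score).float())
    loss.backward()
    optimizer.step()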

yaw_w   = torch.tensor([0]).float()
pitch_w = torch.tensor([0]).float()
roll_w  = torch.tensor([0]).float()

p_wrist = torch.tensor([0., 50., 350.])
theta = torch.tensor([[0., 0., 0., 0.],
                      [0., 0., 0., 0.],
                      [0., 0., 0., 0.],
                      [0., 0., 0., 0.],
                      [0., 0., 0., 0.]])
def R_from_angs(yaw, pitch, roll):
    # build the per-axis rotation matrices and compose them into the wrist rotation
    Rx = torch.tensor([[1, 0, 0],
                       [0, torch.cos(roll), -torch.sin(roll)],
                       [0, torch.sin(roll), torch.cos(roll)]]).float()

    Ry = torch.tensor([[torch.cos(pitch), 0, torch.sin(pitch)],
                       [0, 1, 0],
                       [-torch.sin(pitch), 0, torch.cos(pitch)]]).float()

    Rz = torch.tensor([[torch.cos(yaw), -torch.sin(yaw), 0],
                       [torch.sin(yaw), torch.cos(yaw), 0],
                       [0, 0, 1]]).float()
    R = torch.mm(torch.mm(Rz, Ry), Rx)
    return R
def loss_2d(xyz, uv_pred, c_scores):
    # projecting into image plane using camera intrinsics
    uv_proj = perspective_proj(xyz, torch.as_tensor(ci).float()) 
    uv_diff = uv_proj - torch.as_tensor(uv_pred).float()
    uv_norm = torch.norm(uv_diff, p=2, dim=1).view(21, 1) ** 2  # squared pixel error per keypoint
    loss = torch.sum(c_scores*uv_norm)  # confidence-weighted sum
    return loss
yaw = torch.tensor(yaw_w, requires_grad=True)
pitch = torch.tensor(pitch_w, requires_grad=True)
roll = torch.tensor(roll_w, requires_grad=True)
p = torch.tensor(p_wrist, requires_grad=True)
theta = torch.tensor(theta, requires_grad=True)

R_wrist = R_from_angs(yaw, pitch, roll)

uv_pred, conf_score, xyz_pred, viz_crop = caffe_predictions(img)
# forward kinematics from hand Rotation, translation, and finger theta angles
xyz_fk_init = hand_fk(R_wrist, p, torch.from_numpy(avg_bone_length).float(), parent_unit_vec,theta)

loss = loss_2d(xyz_fk_init, uv_pred, torch.as_tensor(conf_score).float())

When I backpropagate through the custom loss above, I lose the gradients for yaw/pitch/roll. I think it has to do with the fact that I am creating new tensors when building the transformation matrices. Any ideas?

yaw.retain_grad()
pitch.retain_grad()
roll.retain_grad()
p.retain_grad()
theta.retain_grad()

loss.backward()
print(p.grad) #does return a value
print(pitch.grad) #gives me nothing
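If it helps, here is a minimal version of what I suspect is happening (theta standing in for one of my angles):

import torch

theta = torch.tensor(1.0, requires_grad=True)

# the new torch.tensor() call only copies the values of cos/sin into a
# fresh leaf tensor, so the result has no connection back to theta
m = torch.tensor([[torch.cos(theta), -torch.sin(theta)],
                  [torch.sin(theta),  torch.cos(theta)]])

print(m.requires_grad)  # False
print(m.grad_fn)        # None -- the graph is cut here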

I know this is quite a wall of text, so I really appreciate you taking the time to read through it!

I'm running into the same issue.
Did you find a solution or any ideas?
If you did, I'd really appreciate it if you shared them.

Gradients don’t persist through new Tensors.

There are probably many ways to create arbitrary matrices that preserve gradients, though (if that is your issue).

The easy way is to first create a base tensor of the right size and then fill in its entries one by one:

theta = torch.tensor(10., requires_grad=True)
base = torch.ones(2, 2)  # constant base; the in-place fills below put it on the autograd graph
base[0][0] = torch.cos(theta)
base[0][1] = -torch.sin(theta)
base[1][0] = torch.sin(theta)
base[1][1] = torch.cos(theta)
base.backward(torch.ones_like(base))
theta.grad # tensor(1.0880)
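Applied to your wrist rotation, the same pattern for one of the axis matrices would look something like this (a sketch; Rz_from_yaw is just an illustrative name):

import torch

def Rz_from_yaw(yaw):
    # start from a constant base and fill in the yaw-dependent entries;
    # the in-place assignments keep the result connected to yaw
    Rz = torch.eye(3)
    Rz[0, 0] = torch.cos(yaw)
    Rz[0, 1] = -torch.sin(yaw)
    Rz[1, 0] = torch.sin(yaw)
    Rz[1, 1] = torch.cos(yaw)
    return Rz

yaw = torch.tensor(0.3, requires_grad=True)
Rz = Rz_from_yaw(yaw)
Rz.sum().backward()
print(yaw.grad)  # no longer None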

A more advanced way is to create a custom autograd Function. See the tutorial "PyTorch: Defining New autograd Functions" in the PyTorch documentation.

For example, rewriting the above with a custom Function:

class Rotation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, theta):
        # building the matrix with torch.tensor() is fine inside a custom Function,
        # because backward() below supplies the gradient with respect to theta by hand
        rot = torch.tensor([
            [torch.cos(theta), -torch.sin(theta)],
            [torch.sin(theta), torch.cos(theta)]])
        ctx.save_for_backward(theta)
        return rot

    @staticmethod
    def backward(ctx, grad_output):
        theta, = ctx.saved_tensors
        # chain rule: contract d(rot[i][j])/d(theta) with the incoming grad_output[i][j]
        grad_input = -torch.sin(theta)*grad_output[0][0] + -torch.cos(theta)*grad_output[0][1] + \
            torch.cos(theta)*grad_output[1][0] + -torch.sin(theta)*grad_output[1][1]
        return grad_input

rotation = Rotation.apply

theta = torch.tensor(10., requires_grad=True)
rot = rotation(theta)
rot.backward(torch.ones_like(rot))
theta.grad # tensor(1.0880)
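For completeness, another option that avoids writing the backward pass by hand is to build the matrix out of the angle tensor with torch.stack, which keeps every operation on the graph (a sketch of the same 2x2 rotation):

import torch

def rotation_stack(theta):
    # each row is built from tensors that depend on theta, so stacking
    # them preserves the autograd graph end to end
    row0 = torch.stack([torch.cos(theta), -torch.sin(theta)])
    row1 = torch.stack([torch.sin(theta),  torch.cos(theta)])
    return torch.stack([row0, row1])

theta = torch.tensor(10., requires_grad=True)
rot = rotation_stack(theta)
rot.backward(torch.ones_like(rot))
theta.grad # tensor(1.0880), same as above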