How to use the gradient of an intermediate variable in the computation graph of a later variable

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2-D query point, padded with a zero z-coordinate so it can be used in cross products
point_2 = torch.tensor([0.2, 0.8], device=device, requires_grad=True)
p = torch.cat((point_2, torch.tensor([0.0], device=device)), 0)

# Per-vertex coordinates of the triangle; each tensor requires grad
x_verts = torch.tensor([0.0, 1.0, 0.0], device=device, requires_grad=True)
y_verts = torch.tensor([0.0, 0.0, 1.0], device=device, requires_grad=True)
z_verts = torch.tensor([0.1, -0.1, 0.2], device=device, requires_grad=True)

# Index tensors on the same device as the data, and a float zero for padding
idx0 = torch.tensor([0], device=device)
idx1 = torch.tensor([1], device=device)
idx2 = torch.tensor([2], device=device)
zero = torch.tensor([0.0], device=device)

# Triangle vertices lifted into the z = 0 plane
v1_2d = torch.cat((torch.index_select(x_verts, 0, idx0), torch.index_select(y_verts, 0, idx0), zero))
v2_2d = torch.cat((torch.index_select(x_verts, 0, idx1), torch.index_select(y_verts, 0, idx1), zero))
v3_2d = torch.cat((torch.index_select(x_verts, 0, idx2), torch.index_select(y_verts, 0, idx2), zero))


# The z-component of this cross product is twice the signed triangle area
area_3 = torch.cross(v2_2d - v1_2d, v3_2d - v1_2d, dim=0)
area = torch.index_select(area_3, 0, idx2)

# Barycentric-style weights of p with respect to the triangle
alpha_3 = 0.5 * torch.cross(v2_2d - p, v3_2d - p, dim=0) / area
beta_3 = 0.5 * torch.cross(v3_2d - p, v1_2d - p, dim=0) / area
gamma_3 = 0.5 * torch.cross(v1_2d - p, v2_2d - p, dim=0) / area

alpha = torch.index_select(alpha_3, 0, idx2)
beta = torch.index_select(beta_3, 0, idx2)
gamma = torch.index_select(gamma_3, 0, idx2)

# Interpolated z value at p
z = alpha * torch.index_select(z_verts, 0, idx0) + beta * torch.index_select(z_verts, 0, idx1) + gamma * torch.index_select(z_verts, 0, idx2)

z.backward()  # fills point_2.grad (and the vertex .grad fields) with detached gradient tensors


grad_norm = torch.norm(point_2.grad)  # <= disconnection

f = torch.tanh(10.0 * (grad_norm - 2.0))
f.backward() # <= error

print(x_verts.grad)
print(y_verts.grad)
print(z_verts.grad)

I have this code, which fails because I use a variable's .grad value as input to a later computation. How can I fix this?
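For reference, the same disconnection shows up in a two-line example (nothing here beyond plain PyTorch): the .grad field that backward() fills in is an ordinary tensor with no grad_fn, so anything computed from it is cut off from the autograd graph.

x = torch.tensor([1.0, 2.0], requires_grad=True)
(x ** 2).sum().backward()
print(x.grad.requires_grad, x.grad.grad_fn)  # False None: .grad is detached from any graph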

You can write a custom autograd function (a torch.autograd.Function subclass) to define the gradient of your f term directly; see the PyTorch docs on extending torch.autograd: https://pytorch.org/docs/stable/notes/extending.html
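As a minimal sketch of that mechanism (the class name TanhGate is made up; the constants come from the f in your code, and this is not a drop-in solution): forward computes tanh(10 * (x - 2)) and backward returns a derivative written by hand, so autograd never has to trace the forward computation. To apply it to your actual problem you would make the Function take the tensors that do require grad (x_verts, y_verts, z_verts, point_2) as forward inputs, compute the gradient norm and f inside forward, and return your hand-derived gradients of f with respect to those inputs from backward; with this approach the backward has to supply that chain itself, because the stored point_2.grad is a detached tensor that autograd cannot differentiate through.

import torch

class TanhGate(torch.autograd.Function):
    # Hypothetical example: f = tanh(10 * (x - 2)) with a hand-written backward

    @staticmethod
    def forward(ctx, x):
        out = torch.tanh(10.0 * (x - 2.0))
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        # d/dx tanh(10 * (x - 2)) = 10 * (1 - tanh(10 * (x - 2))^2)
        return grad_output * 10.0 * (1.0 - out * out)

x = torch.tensor([1.5], requires_grad=True)
f = TanhGate.apply(x)
f.backward()
print(x.grad)  # 10 * (1 - tanh(10 * (1.5 - 2))**2)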