Second Gradient Computation with autograd yields zeros

Hi everyone,

When I compute the gradient using torch.autograd.grad and then try to compute the second derivative via another call to autograd.grad, I unexpectedly get a tensor of all zeros. However, if I compute the second derivative using either:

  1. The Jacobian of the gradient and take its trace (i.e., sum of second partials), or
  2. Loop through each input component and compute its second derivative individually,

I get the correct analytical result. The problem is:

  • Using Jacobians: Although I get the correct result, the model’s loss behaves strangely during training and doesn’t converge.
  • Per-input gradient computation: This works, but it is extremely slow, which is not feasible for my application.

Has anyone encountered this issue with autograd.grad returning zero second derivatives?

Here is a small test case:

```python
import torch
from torch.autograd.functional import jacobian

def minimal_test(x):
    xi = x[:, 0, :]
    xj = x[:, 1, :]
    diff = xi - xj
    r = torch.sqrt(torch.sum(diff ** 2, dim=-1) + 1e-10)
    sigma = torch.exp(torch.tensor(0.15))
    k = torch.sigmoid(torch.tensor(0.5))
    phi = torch.pow(sigma / r, 2 / k)

    energy = phi.sum(dim=-1)
    return -energy

# Forward pass
x_t = torch.tensor([[[ 0.4806, -0.5267], [ 0.5513,  0.6484]]], requires_grad=True)  # Simplified input, shape (B, N, D)

def psi(x_t):
    output = minimal_test(x_t)
    gradients = torch.autograd.grad(
        outputs=output, inputs=x_t, grad_outputs=torch.ones_like(output),
        retain_graph=True, create_graph=True,
    )[0]
    return gradients

gradients = psi(x_t)
grad_outputs = torch.ones_like(gradients)
divergence = torch.autograd.grad(outputs=gradients, inputs=x_t, grad_outputs=grad_outputs, create_graph=True)[0]
print("Gradients:", gradients)
print("Second Gradients (autograd):", divergence)

jac = jacobian(psi, x_t)

divergence = torch.zeros(x_t.shape[:-1], device=x_t.device)  # shape: (B, N)
for b in range(x_t.shape[0]):
    for n in range(x_t.shape[1]):
        # Diagonal block of the Jacobian for particle (b, n) is D x D
        J = jac[b, n, :, b, n, :]  # shape: (D, D)
        divergence[b, n] = torch.trace(J)

print("Second Gradients (jacobian):", divergence)

divergence = torch.zeros(x_t.shape[:-1], device=x_t.device)  # shape: (B, N)
for b in range(x_t.shape[0]):
    for n in range(x_t.shape[1]):
        for d in range(x_t.shape[2]):
            divergence[b, n] += torch.autograd.grad(
                gradients[b, n, d], x_t, retain_graph=True, create_graph=True,
            )[0][b, n, d]

print("Second Gradients (autograd per input):", divergence)
```

This is the output:

```
Gradients: tensor([[[-0.1571, -2.6116],
         [ 0.1571,  2.6116]]], grad_fn=<AddBackward0>)
Second Gradients (autograd): tensor([[[0., 0.],
         [0., 0.]]], grad_fn=<AddBackward0>)
Second Gradients (jacobian): tensor([[-7.1409, -7.1409]])
Second Gradients (autograd per input): tensor([[-7.1409, -7.1409]], grad_fn=<CopySlices>)
```

P.S.
Setting any entry in grad_outputs to 0 returns a non-zero result; the problem only occurs when trying to get the full second derivative.

In your second (and probably your third, since it gives the same results) computation of the divergence, you start by computing the hessian: jac = jacobian(psi, x_t). This is the jacobian of the gradients with respect to the input, i.e. the hessian. Note that you would get the same result with hessian = torch.autograd.functional.hessian(minimal_test, x_t).

This hessian is of shape [B, N, D, B, N, D], which is why it’s so expensive to compute, no matter how you do it.
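A minimal sketch to check both claims above, reusing `minimal_test` from the question; `psi` here is a condensed version of the question's helper. It assumes a PyTorch version that ships `torch.autograd.functional.jacobian` and `torch.autograd.functional.hessian`:

```python
import torch
from torch.autograd.functional import hessian, jacobian


def minimal_test(x):
    # Energy of the pair (particle 0, particle 1), copied from the question
    diff = x[:, 0, :] - x[:, 1, :]
    r = torch.sqrt(torch.sum(diff ** 2, dim=-1) + 1e-10)
    sigma = torch.exp(torch.tensor(0.15))
    k = torch.sigmoid(torch.tensor(0.5))
    phi = torch.pow(sigma / r, 2 / k)
    return -phi.sum(dim=-1)  # scalar output


def psi(x):
    # Gradient of the scalar energy w.r.t. x, kept differentiable
    out = minimal_test(x)
    return torch.autograd.grad(out, x, create_graph=True)[0]


x_t = torch.tensor([[[0.4806, -0.5267], [0.5513, 0.6484]]], requires_grad=True)

H_from_jac = jacobian(psi, x_t)   # jacobian of the gradient
H = hessian(minimal_test, x_t)    # hessian computed directly

print(H.shape)  # torch.Size([1, 2, 2, 1, 2, 2]), i.e. [B, N, D, B, N, D]
print(torch.allclose(H_from_jac, H, atol=1e-5))
```

Even for this tiny input the full hessian has (B*N*D)^2 = 16 entries; it grows quadratically with the number of particles.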

You’re then taking, for each (b, n), the trace of the D x D diagonal block jac[b, n, :, b, n, :] of the hessian, and using this as the divergence.
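For reference, here is a sketch of that quantity computed two ways: (1) read off the diagonal of the full hessian in one vectorized step, and (2) with one vector-hessian product per (n, d) pair, without ever building the hessian. Variant (2) is an alternative not discussed in this thread; it assumes the samples in the batch do not interact (true for `minimal_test`, whose output is a sum of per-sample terms), so the batch dimension can share each backward pass. It needs N * D backward passes instead of the B * N * D of the per-input loop in the question.

```python
import torch
from torch.autograd.functional import hessian


def minimal_test(x):
    # Energy of the pair (particle 0, particle 1), copied from the question
    diff = x[:, 0, :] - x[:, 1, :]
    r = torch.sqrt(torch.sum(diff ** 2, dim=-1) + 1e-10)
    sigma = torch.exp(torch.tensor(0.15))
    k = torch.sigmoid(torch.tensor(0.5))
    phi = torch.pow(sigma / r, 2 / k)
    return -phi.sum(dim=-1)


x_t = torch.tensor([[[0.4806, -0.5267], [0.5513, 0.6484]]], requires_grad=True)
B, N, D = x_t.shape

# (1) From the full hessian of shape (B, N, D, B, N, D): flatten it to a square
# matrix, keep its diagonal, and sum the D entries belonging to each particle
H = hessian(minimal_test, x_t)
div_from_hessian = H.reshape(B * N * D, B * N * D).diagonal().reshape(B, N, D).sum(dim=-1)

# (2) Without building H: for a one-hot v = e_(n, d) on every sample, v^T H at
# position [b, n, d] equals H[b, n, d, b, n, d] (assumes no cross-sample terms)
grad = torch.autograd.grad(minimal_test(x_t), x_t, create_graph=True)[0]
per_particle = []
for n in range(N):
    acc = torch.zeros(B)
    for d in range(D):
        v = torch.zeros_like(grad)
        v[:, n, d] = 1.0
        # create_graph=True keeps the result differentiable, e.g. for use in a loss
        hv = torch.autograd.grad(grad, x_t, grad_outputs=v, create_graph=True)[0]
        acc = acc + hv[:, n, d]
    per_particle.append(acc)
div_vjp = torch.stack(per_particle, dim=1)  # shape (B, N)

print(div_from_hessian)
print(div_vjp)
```

Both variants return a (B, N) tensor; only the second avoids materializing the (B*N*D) x (B*N*D) hessian.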

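The first time you’re computing the divergence, however, you’re directly computing something equal to hessian.sum(dim=[0, 1, 2]): with grad_outputs set to all ones, autograd.grad returns the vector-jacobian product ones^T · hessian, i.e. the hessian summed over its first three dimensions. This is not the divergence, and I think it’s normal for it to be zero in your case: minimal_test depends on x only through xi - xj, so shifting both particles by the same vector leaves the energy unchanged, the gradient entries of the two particles cancel (as in your printed gradients), and so do these sums.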
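A small sketch illustrating this, again reusing `minimal_test`: the all-ones vector-jacobian product matches the column sums of the hessian and vanishes, while zeroing a single entry of `grad_outputs` (as in the P.S.) leaves one uncancelled hessian row, which is why that variant returned something non-zero.

```python
import torch
from torch.autograd.functional import hessian


def minimal_test(x):
    # Energy of the pair (particle 0, particle 1), copied from the question
    diff = x[:, 0, :] - x[:, 1, :]
    r = torch.sqrt(torch.sum(diff ** 2, dim=-1) + 1e-10)
    sigma = torch.exp(torch.tensor(0.15))
    k = torch.sigmoid(torch.tensor(0.5))
    phi = torch.pow(sigma / r, 2 / k)
    return -phi.sum(dim=-1)


x_t = torch.tensor([[[0.4806, -0.5267], [0.5513, 0.6484]]], requires_grad=True)

H = hessian(minimal_test, x_t)
grad = torch.autograd.grad(minimal_test(x_t), x_t, create_graph=True)[0]

# grad_outputs = ones: result[b', n', d'] = sum over (b, n, d) of H[b, n, d, b', n', d']
ones = torch.ones_like(grad)
vjp_ones = torch.autograd.grad(grad, x_t, grad_outputs=ones, retain_graph=True)[0]
col_sums = H.sum(dim=[0, 1, 2])

# The energy only depends on xi - xj, so these sums cancel across the two particles
print(torch.allclose(vjp_ones, col_sums, atol=1e-4))
print(vjp_ones.abs().max())

# P.S. case: with v[0, 0, 0] = 0, v^T H = col_sums - H[0, 0, 0], which is non-zero
v = ones.clone()
v[0, 0, 0] = 0.0
vjp_masked = torch.autograd.grad(grad, x_t, grad_outputs=v)[0]
print(vjp_masked)
```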

Ah, I see: the code worked for the Poisson process, but everything broke with Gibbs. I guess in the Poisson case hessian.sum(dim=[0, 1, 2]) accidentally gives the correct result, even though the logic behind it is incorrect. Thanks for the explanation! 😉
