Efficient Computation of Second-Order Gradients

Hi everyone,

I’m working on computing the score matching objective for a softcore potential, and I’m running into performance issues when calculating the second-order gradients (i.e., the divergence of the gradient field).

Here’s a simplified example of what I’m doing:

import torch

def minimal_test(x):
    # x: (B, N, D); here a single pair of particles, N = 2
    xi = x[:, 0, :]
    xj = x[:, 1, :]
    diff = xi - xj
    r = torch.sqrt(torch.sum(diff ** 2, dim=-1) + 1e-10)  # pairwise distance, (B,)
    sigma = torch.exp(torch.tensor(0.15))   # fixed parameters for the example
    k = torch.sigmoid(torch.tensor(0.5))
    phi = torch.pow(sigma / r, 2 / k)       # soft-core pair potential, (B,)

    energy = phi.sum(dim=-1)
    return -energy

x_t = torch.tensor([[[0.4806, -0.5267], [0.5513, 0.6484]]], requires_grad=True)  # Shape: (B, N, D) = (1, 2, 2)

def psi(x_t):
    # Gradient field of the energy; create_graph=True keeps the graph
    # so we can differentiate a second time below.
    output = minimal_test(x_t)
    gradients = torch.autograd.grad(
        outputs=output, inputs=x_t, grad_outputs=torch.ones_like(output),
        retain_graph=True, create_graph=True,
    )[0]
    return gradients

gradients = psi(x_t)

# Compute divergence via second-order derivatives (inefficient!):
# one extra backward pass per particle i and coordinate d.
divergence = torch.zeros(x_t.shape[0], x_t.shape[1])
for d in range(x_t.shape[-1]):
    for i in range(x_t.shape[-2]):
        second_grad = torch.autograd.grad(
            gradients[:, i, d].sum(), x_t, create_graph=True)[0][:, i, d]  # shape: (B,)
        divergence[:, i] += second_grad

print("Second Gradients (autograd per input):", divergence)

The issue is that this per-element loop over i and d is very slow, especially when scaling up to more particles or higher dimensions. I tried using vmap, but couldn’t make it work: the indexing gradients[:, i, d] seems to break batching.

Is there any cleaner or more efficient way to compute this divergence without having to loop over each element individually?

Thanks in advance for any help or pointers!

If I understand correctly, you are computing the trace of the Hessian (a.k.a. the Laplacian)? There is a built-in function for computing the Hessian: torch.autograd.functional.hessian, or torch.func.hessian in recent releases.
Since you only need the diagonal elements, this computes a lot of entries you don’t need, but I still think there is a chance it ends up faster than the Python loop, since everything is batched.
Efficient implementations (using vmap to handle the vectorization) following that idea have also been discussed here.
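
Something along these lines might work (just a sketch, assuming PyTorch ≥ 2.0 for torch.func; energy_single is a hypothetical per-sample rewrite of your minimal_test):

import torch
from torch.func import hessian, vmap

def energy_single(x):
    # Per-sample version of minimal_test: x has shape (N, D), no batch dim.
    diff = x[0] - x[1]
    r = torch.sqrt(torch.sum(diff ** 2) + 1e-10)
    sigma = torch.exp(torch.tensor(0.15))
    k = torch.sigmoid(torch.tensor(0.5))
    return -torch.pow(sigma / r, 2 / k)

B, N, D = x_t.shape
# The Hessian of a scalar function of an (N, D) input has shape (N, D, N, D);
# vmap adds the batch dimension back: (B, N, D, N, D).
H = vmap(hessian(energy_single))(x_t)

# Keep only the diagonal entries d^2E/dx_{i,d}^2 and sum over coordinates.
diag = torch.diagonal(H.reshape(B, N * D, N * D), dim1=-2, dim2=-1)
divergence = diag.reshape(B, N, D).sum(dim=-1)  # (B, N)

This still materializes the full (N*D) x (N*D) Hessian per sample, so it only pays off while N*D is small, but it replaces the Python loop with one batched call.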

If this is still too inefficient (and this might be a moonshot), there are stochastic methods for estimating the trace of the Hessian, e.g. Hutchinson’s estimator, but they are more involved.
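
Very roughly, and with hutchinson_divergence / n_probes being names I’m making up: for Rademacher probes v, E[v * (H v)] equals the diagonal of H, so you can even recover a per-particle estimate, and each Hessian-vector product costs a single extra backward pass:

def hutchinson_divergence(x_t, n_probes=8):
    # Stochastic estimate of the per-particle Laplacian. Each probe needs one
    # Hessian-vector product instead of N * D separate backward passes.
    grads = psi(x_t)  # (B, N, D); psi already uses create_graph=True
    est = torch.zeros(x_t.shape[0], x_t.shape[1])
    for _ in range(n_probes):
        v = torch.empty_like(x_t).bernoulli_(0.5) * 2 - 1  # random +/-1 probe
        Hv = torch.autograd.grad((grads * v).sum(), x_t, retain_graph=True)[0]
        est = est + (v * Hv).sum(dim=-1)  # E[v * Hv] = diag(H), summed over d
    return est / n_probes

Note this is an unbiased estimate, not the exact value; the variance shrinks with more probes, and if you need to backprop through it you’d add create_graph=True to the inner grad call.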

Hope that helps!