Hi everyone,
I’m working on computing the score matching objective for a softcore potential, and I’m running into performance issues when calculating the second-order gradients (i.e., the divergence of the gradient field).
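Concretely, if f(x) denotes the scalar (negative energy) returned by minimal_test below, the quantity I need per particle i is the sum of the diagonal Hessian entries over its coordinates, sum_d ∂²f / ∂x_{i,d}², i.e. the per-particle divergence of the score; the loop at the end accumulates exactly that into divergence[:, i].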
Here’s a simplified example of what I’m doing:
import torch

def minimal_test(x):
    # x: (B, N, D) particle positions; here N = 2 particles in D = 2 dimensions
    xi = x[:, 0, :]
    xj = x[:, 1, :]
    diff = xi - xj
    # Pairwise distance, with a small epsilon to avoid dividing by zero
    r = torch.sqrt(torch.sum(diff ** 2, dim=-1) + 1e-10)
    sigma = torch.exp(torch.tensor(0.15))
    k = torch.sigmoid(torch.tensor(0.5))
    # Softcore pair potential: phi = (sigma / r)^(2 / k)
    phi = torch.pow(sigma / r, 2 / k)
    energy = phi.sum(dim=-1)
    return -energy

x_t = torch.tensor([[[0.4806, -0.5267], [0.5513, 0.6484]]], requires_grad=True)  # Shape: (B, N, D)
def psi(x_t):
    # Score: gradient of the scalar output of minimal_test w.r.t. the positions
    output = minimal_test(x_t)
    gradients = torch.autograd.grad(
        outputs=output, inputs=x_t, grad_outputs=torch.ones_like(output),
        retain_graph=True, create_graph=True,
    )[0]
    return gradients
gradients = psi(x_t)
# Compute divergence via second-order derivatives (inefficient!)
divergence = torch.zeros(x_t.shape[0], x_t.shape[1])
for d in range(x_t.shape[-1]):       # spatial dimensions
    for i in range(x_t.shape[-2]):   # particles
        # One diagonal Hessian entry per batch element:
        # d(gradients[:, i, d]) / d(x_t[:, i, d])
        second_grad = torch.autograd.grad(
            gradients[:, i, d].sum(), x_t, create_graph=True)[0][:, i, d]  # shape: (B,)
        divergence[:, i] += second_grad
print("Second Gradients (autograd per input):", divergence)
The issue is that this per-element loop over i and d is very slow (it needs one extra backward pass per (particle, dimension) pair), especially when scaling up to more particles or higher dimensions. I tried using vmap, but couldn't make it work because of the indexing gradients[:, i, d], which seems to break batching.
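To make that concrete, this is roughly the shape of my vmap attempt; it's a reconstruction rather than my exact code, and the helper one_entry and the in_dims layout are just illustrative:

from torch.func import vmap

idx_i = torch.arange(x_t.shape[1])  # particle indices
idx_d = torch.arange(x_t.shape[2])  # coordinate indices

def one_entry(i, d):
    # Differentiate one component of the score w.r.t. the matching input entry
    return torch.autograd.grad(
        gradients[:, i, d].sum(), x_t, create_graph=True)[0][:, i, d]

# Mapping over all (i, d) pairs is where it breaks for me:
# second = vmap(vmap(one_entry, in_dims=(None, 0)), in_dims=(0, None))(idx_i, idx_d)
# -> the gradients[:, i, d] indexing doesn't seem to batch under vmap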
Is there any cleaner or more efficient way to compute this divergence without having to loop over each element individually?
Thanks in advance for any help or pointers!