How to calculate Laplacian (sum of 2nd derivatives) in one step

Let’s say f: R^N -> R, and f is batched so that f(xs) -> ys, where xs.shape == (n_batch, n_dim) and ys.shape == (n_batch,). I want to calculate the Laplacian of f, that is, \sum_i d^2/dx_i^2 f(x).

At the moment, I evaluate ys = f(xs), then iterate over n_dim, and for each i I calculate

diys = grad(ys, xis, ones(n_batch), create_graph=True)
didiys = grad(diys, xis, ones(n_batch), retain_graph=True)

and then sum all didiys. Is there a way to do this in one step rather than iterating over dimensions?
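For concreteness, a runnable version of this per-dimension loop might look something like the sketch below; the quadratic f and the shapes are just placeholders, and differentiating w.r.t. the full xs and picking out the i-th column is equivalent to differentiating w.r.t. each xi separately:

import torch

def f(xs):                            # placeholder: f(x) = sum_i x_i^2
    return (xs ** 2).sum(dim=-1)

xs = torch.randn(5, 3, requires_grad=True)   # (n_batch, n_dim)
ys = f(xs)
ones = torch.ones_like(ys)
# first derivatives dy/dx, shape (n_batch, n_dim)
(dys,) = torch.autograd.grad(ys, xs, ones, create_graph=True)
# second derivatives d^2y/dx_i^2: one backward pass per dimension, then sum
lap = sum(
    torch.autograd.grad(dys[:, i], xs, ones, retain_graph=True)[0][:, i]
    for i in range(xs.shape[1])
)
print(lap)   # ≈ tensor([6., 6., 6., 6., 6.]) for this f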

It is somewhat related to Issue #7786 where f_1, f_2 would be d/dx1 f(x), d/dx2 f(x).

Did you find any solution to this? I also need to calculate Laplacians quite often; my current way of doing this by iterating (see below) is quite slow (and CPU-based…):

import torch

def laplace(fx: torch.Tensor, x: torch.Tensor):
    """
    Laplacian (= sum of 2nd derivatives)
    of (evaluated) nd->1d-function fx w.r.t. nd-tensor x
    :rtype: torch.Tensor
    """
    # first derivatives df/dx_i, same shape as x
    dfx = torch.autograd.grad(fx, x, create_graph=True)[0]
    ddfx = []
    for i in range(len(x)):
        # one-hot vector selecting the i-th component of dfx
        vec = torch.tensor([(1.0 if i == j else 0.0) for j in range(len(dfx))], dtype=torch.float)
        # i-th diagonal element of the Hessian, d^2 f / d x_i^2
        ddfx += [torch.autograd.grad(
            dfx,
            x,
            create_graph=True,
            grad_outputs=vec
        )[0][i]]
    ret = sum(ddfx)
    return ret
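For example, a quick sanity check on a placeholder quadratic, where the Laplacian should be 2 per dimension:

x = torch.randn(3, requires_grad=True)
fx = (x ** 2).sum()      # f(x) = |x|^2
print(laplace(fx, x))    # ≈ 6.0 (= 2 per dimension)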

Can you use the new Hessian function introduced in PyTorch 1.5? https://pytorch.org/docs/stable/autograd.html#torch.autograd.functional.hessian
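For a single (un-batched) sample that could look roughly like the sketch below (f_single is just a stand-in quadratic); the Laplacian is then the trace of the Hessian, though for a batched f you would still have to loop over samples:

import torch
from torch.autograd.functional import hessian

def f_single(x):                  # placeholder: f(x) = |x|^2, so the Laplacian is 2 * n_dim
    return (x ** 2).sum()

x = torch.randn(3)
H = hessian(f_single, x)          # full (n_dim, n_dim) Hessian
lap = torch.diagonal(H).sum()     # Laplacian = trace of the Hessian
print(lap)                        # ≈ 6.0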

No, I didn’t, I ended up using this:

import torch

def laplacian(xs, f, create_graph=False, keep_graph=None, return_grad=False):
    # treat each input dimension as its own leaf tensor so that second
    # derivatives can be taken w.r.t. one dimension at a time
    xis = [xi.requires_grad_() for xi in xs.flatten(start_dim=1).t()]
    xs_flat = torch.stack(xis, dim=1)
    ys = f(xs_flat.view_as(xs))
    (ys_g, *other) = ys if isinstance(ys, tuple) else (ys, ())
    ones = torch.ones_like(ys_g)
    # first derivatives dy/dx, shape (n_batch, n_dim)
    (dy_dxs,) = torch.autograd.grad(ys_g, xs_flat, ones, create_graph=True)
    # second derivatives d^2y/dx_i^2, summed over dimensions
    lap_ys = sum(
        torch.autograd.grad(
            dy_dxi, xi, ones, retain_graph=True, create_graph=create_graph
        )[0]
        for xi, dy_dxi in zip(xis, (dy_dxs[..., i] for i in range(len(xis))))
    )
    if not (create_graph if keep_graph is None else keep_graph):
        ys = (ys_g.detach(), *other) if isinstance(ys, tuple) else ys.detach()
    result = lap_ys, ys
    if return_grad:
        result += (dy_dxs.detach().view_as(xs),)
    return result
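A hypothetical usage sketch (the quadratic f is just an example of a function mapping (n_batch, n_dim) to (n_batch,)):

import torch

def f(xs):                        # f(x) = sum_i x_i^2, so the Laplacian is 2 * n_dim
    return (xs ** 2).sum(dim=-1)

xs = torch.randn(4, 3)
lap, ys = laplacian(xs, f)
print(lap)                        # ≈ tensor([6., 6., 6., 6.])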

But actually, to calculate the Laplacian you need to calculate the Hessian; there is no way around that. So as long as your batch is large enough, the for loop should introduce no significant overhead compared to the theoretically optimal implementation. The Laplacian will always scale as N^2.
