Fast computation of the Hessian diagonal

Hi,

In my use case, I need to compute the diagonal of the Hessian for 3D tensors.

The code below illustrates what I need to compute. It works; however, it does not scale well as the tensor dimensions grow, since it builds an n x n identity as grad_outputs and evaluates one VJP per output element (n = shape[1] * shape[2]).

Is there a faster way to achieve the same computation?

I am on torch=1.10.0.

from time import time

import torch
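# _vmap is torch's internal (experimental) vmap prototype available in 1.10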
from torch._vmap_internals import _vmap as vmap

def grad(outputs, inputs, **kwargs):
    """
    Compute the diagonal of the Jacobian of a 3D tensor.

    grad_{ijk} = d (outputs_{ijk}) / d (inputs_{ijk})

    See: https://discuss.pytorch.org/t/jacobian-functional-api-batch-respecting-jacobian/84571/7?u=amerlo94
    """

    shape = outputs.shape
    bs = shape[0]
    n = shape[1] * shape[2]

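    # Flatten the non-batch dims; summing over the batch keeps the per-sample
    # gradients separate (this relies on each sample's outputs depending only
    # on that sample's inputs, as in the linked thread).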
    outputs = outputs.view(bs, n)
    outputs = outputs.sum(axis=0)
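    # One identity row per output element: each VJP below extracts one row of the Jacobian.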
    grad_outputs = torch.eye(n, dtype=outputs.dtype, device=outputs.device)

    def get_vjp(v):
        return torch.autograd.grad(outputs, inputs, v, **kwargs)[0]

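    # Batch all n VJPs into a single vmapped call instead of a Python loop.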
    vjp = vmap(get_vjp)
    grad = vjp(grad_outputs)

    return grad.T.view(shape)


n = 16
b = torch.randn(1, n, n)


def f(x):
    return x ** 2 * b


bs = 128
iterations = 1000

x = torch.rand(bs).requires_grad_(True)
x = x.view(-1, 1, 1)
y = f(x)

#  Check gradients
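#  For y_{ijk} = x_i^2 * b_{jk}: d y_{ijk} / d x_i = 2 * x_i * b_{jk} and d^2 y_{ijk} / d x_i^2 = 2 * b_{jk}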
jac = grad(y, x, create_graph=True)
hes = grad(jac, x)
assert torch.allclose(jac, 2 * x * b)
assert torch.allclose(hes, 2 * b)

#  Time gradients
t0 = time()
for _ in range(iterations):
    jac = grad(y, x, create_graph=True)
    grad(jac, x)

tottime = time() - t0

time_per_sample = tottime / iterations / bs
print(f"Time per sample: {time_per_sample * 1e6:.2f} us")