Finding the Diagonal of Hessian wrt Input for Vector-Valued Functions

My goal is to compute the hessian wrt the input to the function / neural net, not wrt its parameters.

Suppose you have a vector-valued function in PyTorch, f: R^n → R^m.
Computing its hessian should return an [M, N, N] tensor, and taking the diagonal of each [N, N] slice would give an [M, N] matrix.
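For reference, here is a minimal sketch of those shapes. It assumes PyTorch 2.x (torch.func) and uses a small random Linear(3, 5) + sigmoid as an illustrative stand-in for the network; the seed and layer are not from the original code. torch.diagonal with dim1=-2, dim2=-1 takes the diagonal of each [N, N] slice:

```python
import torch
from torch.func import hessian

torch.manual_seed(0)
# Illustrative stand-in for f_single_vector: R^3 -> R^5
layer = torch.nn.Linear(3, 5).requires_grad_(False)

def f(x):
    return torch.sigmoid(layer(x))

x = torch.tensor([1.0, 2.0, 3.0])
H = hessian(f)(x)                              # [M, N, N] = [5, 3, 3]
H_diag = torch.diagonal(H, dim1=-2, dim2=-1)   # [M, N] = [5, 3], H_diag[i, k] = H[i, k, k]
print(H.shape, H_diag.shape)
```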

Computing this whole hessian and then extracting its diagonal is too expensive for large functions like neural nets.

There is also a technique in the PyTorch docs for doing something similar with a scalar-valued function (f: R^n → R, whose hessian diagonal would be an [N] vector), using hvp:

def hvp(f, x, v):
    return jvp(grad(f), (x,), (v,))[1]

but using it directly on a vector-valued function raises an error (grad requires f to return a scalar).
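One way around that error, sketched under the assumption of PyTorch 2.x: contract f(x) with a one-hot vector u so each output f_i becomes a scalar the docs' hvp accepts, then vmap over the M one-hot vectors. The helper name hvp_per_output and the example layer are hypothetical. Note that each row of the result is H_i @ v, a hessian-vector product per output, not a diagonal:

```python
import torch
from torch.func import vmap, jvp, grad

torch.manual_seed(0)
# Hypothetical example network: R^3 -> R^5
layer = torch.nn.Linear(3, 5).requires_grad_(False)

def f(x):
    return torch.sigmoid(layer(x))

def hvp(f, x, v):
    # Scalar-function hvp from the PyTorch docs
    return jvp(grad(f), (x,), (v,))[1]

def hvp_per_output(f, x, v):
    # Hypothetical helper: u . f(x) is scalar-valued for each one-hot u,
    # so hvp applies; vmap stacks the M results into an [M, N] tensor.
    M = f(x).shape[0]
    def one_output(u):
        return hvp(lambda y: (f(y) * u).sum(), x, v)
    return vmap(one_output)(torch.eye(M))

x = torch.tensor([1.0, 2.0, 3.0])
v = torch.ones_like(x)
out = hvp_per_output(f, x, v)   # [5, 3]; row i is H_i @ v
```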

I’ve tried writing my own version, hvp_vecfwd:

# My attempt
def hvp_vecfwd(f, x, v):
    return jvp(jacrev(f), (x,), (v,))[1]

It does produce the correct shape, but with the tangents

 torch.ones_like(x_single)

the values are not all correct, even on simple vector functions (f_single_vector in the code below). Is the problem my choice of tangents, my implementation, or both?
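To see what hvp_vecfwd actually computes, here is a sketch comparing it with the full hessian (assuming PyTorch 2.x; the layer, seed and tangent v are illustrative). jacrev(f) returns the [M, N] Jacobian, and its jvp along v is the directional derivative of that Jacobian, so row i equals H_i @ v rather than the diagonal of H_i:

```python
import torch
from torch.func import jvp, jacrev, hessian

torch.manual_seed(0)
layer = torch.nn.Linear(3, 5).requires_grad_(False)

def f(x):
    return torch.sigmoid(layer(x))

def hvp_vecfwd(f, x, v):
    # Derivative of the [M, N] Jacobian in direction v
    return jvp(jacrev(f), (x,), (v,))[1]

x = torch.tensor([1.0, 2.0, 3.0])
v = torch.tensor([0.5, -1.0, 2.0])             # arbitrary tangent
hv = hvp_vecfwd(f, x, v)                       # [M, N]
H = hessian(f)(x)                              # [M, N, N]
# hv[i, j] = sum_k H[i, j, k] * v[k], i.e. one hessian-vector product per output
expected = torch.einsum("mjk,k->mj", H, v)
print(torch.allclose(hv, expected, atol=1e-5))
```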

Any help or insights on this would be great!

import torch
from torch.func import vmap, jvp, vjp, grad, jacrev, jacfwd, hessian
from functools import partial

x_single = torch.tensor([1.0, 2.0, 3.0], requires_grad=False)
x_batch = torch.tensor([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0],
                        [7.0, 8.0, 9.0]], requires_grad=False)
layer = torch.nn.Linear(3, 5).requires_grad_(False)


# For Scalar Functions
# Provided by PyTorch Docs
def hvp(f, x, v):
    return jvp(grad(f), (x,), (v,))[1]

# For Vector Functions
# My attempt
def hvp_vecfwd(f, x, v):
    return jvp(jacrev(f), (x,), (v,))[1]

# Simple functions for debugging

# Batched f: [B, R^N] -> [B, R]
def f_batch_scalar(x):
    return torch.sum(x*x, dim=-1)


# f: R^N -> R^M
def f_single_vector(x):
    return torch.sigmoid(layer(x))


# Batched f: [B, R^N] -> [B, R^M]
def f_batch_vector(x):
    return torch.sigmoid(layer(x))


print("Batched Full Hessian of Scalar Function")
# Batched f: [B, R^N] -> [B, R], d^2f/dx^2 = [B, N, N]
print(vmap(hessian(f_batch_scalar))(x_batch))

print("Batched Diagonal Hessian of Scalar Function")
# Batched f: [B, R^N] -> [B, R], diag[H] = [B, N]
print(vmap(hvp, in_dims=(None, 0, 0))(f_batch_scalar, x_batch, torch.ones_like(x_batch)))

print("Non-Batched Full Hessian of Vector Function")
# f: R^N -> R^M, d^2f/dx^2 = [M, N, N]
print(hessian(f_single_vector)(x_single))

print("Non-Batched Diagonal Hessian of Vector Function")
# f: R^N -> R^M, diag[H] = [M, N]
# WRONG
print(hvp_vecfwd(f_single_vector, x_single,  torch.ones_like(x_single)))

print("Batched Full Hessian of Vector Function")
# Batched f: [B, R^N] -> [B, R^M], d^2f/dx^2 = [B, M, N, N]
print(vmap(hessian(f_batch_vector))(x_batch))

print("Batched Diagonal Hessian of Vector Function")
# Batched f: [B, R^N] -> [B, R^M], diag[H] = [B, M, N]
# WRONG
print(vmap(hvp_vecfwd, in_dims=(None, 0, 0))(f_batch_vector, x_batch, torch.ones_like(x_batch)))
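A quick check of why the scalar case looks right while the lines marked WRONG do not, as a sketch with the same kinds of functions (the seed and layer are illustrative). With v = torch.ones_like(x), H @ v is the row sum of each hessian. For sum(x*x) the hessian is 2*I, so row sums and diagonal coincide; for sigmoid(layer(x)) they generally do not:

```python
import torch
from torch.func import jvp, grad, jacrev, hessian

torch.manual_seed(0)
layer = torch.nn.Linear(3, 5).requires_grad_(False)

def hvp(f, x, v):
    return jvp(grad(f), (x,), (v,))[1]

def hvp_vecfwd(f, x, v):
    return jvp(jacrev(f), (x,), (v,))[1]

def f_scalar(x):
    return torch.sum(x * x, dim=-1)

def f_vector(x):
    return torch.sigmoid(layer(x))

x = torch.tensor([1.0, 2.0, 3.0])
ones = torch.ones_like(x)

# Scalar case: hessian is 2*I, so H @ ones happens to equal diag(H)
H_s = hessian(f_scalar)(x)
print(hvp(f_scalar, x, ones), torch.diagonal(H_s))

# Vector case: hvp_vecfwd with ones returns the row sums of each H_i
H_v = hessian(f_vector)(x)                          # [5, 3, 3]
row_sums = H_v.sum(dim=-1)                          # [5, 3]
diag = torch.diagonal(H_v, dim1=-2, dim2=-1)        # [5, 3]
out = hvp_vecfwd(f_vector, x, ones)
print(torch.allclose(out, row_sums, atol=1e-5), torch.allclose(out, diag, atol=1e-5))
```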

Hi jdeo!

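You say that computing the whole hessian is too expensive for large functions like neural nets. Why do you say this? You want the (diagonal of the) hessian with
respect to your R^n input, so the size of the network doesn’t affect how large
your [N, N] hessian matrix is (or [M, N, N] for a vector-valued function).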
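To illustrate, a sketch with a hypothetical network of 68,101 parameters (assuming PyTorch 2.x): the hessian with respect to the input still has shape [M, N, N]. Tanh is used because a ReLU network's input hessian is zero almost everywhere:

```python
import torch
from torch.func import hessian

torch.manual_seed(0)
N, M = 3, 5
# Hypothetical larger network: many parameters, but still R^3 -> R^5
net = torch.nn.Sequential(
    torch.nn.Linear(N, 256), torch.nn.Tanh(),
    torch.nn.Linear(256, 256), torch.nn.Tanh(),
    torch.nn.Linear(256, M),
).requires_grad_(False)

x = torch.randn(N)
H = hessian(net)(x)                             # hessian wrt the input only
n_params = sum(p.numel() for p in net.parameters())
print(n_params, H.shape)
```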

Consider your gradient of shape [N]. To compute the [0, 0] element of your
hessian, you compute the gradient of the [0] element of your gradient and
keep just the [0] element of that result. But (for any realistic network) you
wouldn’t save any significant work by computing just the [0] element of
that gradient.
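The procedure just described, written out with plain torch.autograd on a hypothetical small scalar-valued network (x needs requires_grad=True here, unlike the original code): the second torch.autograd.grad call returns the whole [N] row H[0, :], and only its [0] entry is kept:

```python
import torch

torch.manual_seed(0)
# Hypothetical scalar-valued network: R^3 -> R
net = torch.nn.Sequential(torch.nn.Linear(3, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  # needed for torch.autograd.grad
y = net(x).squeeze()                                   # scalar output
(g,) = torch.autograd.grad(y, x, create_graph=True)    # gradient, shape [N]
(row0,) = torch.autograd.grad(g[0], x)                 # full row H[0, :], shape [N]
h00 = row0[0]                                          # keep only H[0, 0]
```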

When you compute the gradient of the [0] element of the first gradient,
autograd chains together a series of vector-jacobian products, one for each layer.
Roughly speaking, all of the terms get mixed together in each layer, so you
have to perform the full products for all of the intermediate layers, even
though you know that you will discard all but one element of the final gradient.
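The same per-element procedure can be batched over all k with vmap and applied to the vector-valued f (a sketch assuming PyTorch 2.x; the helper name hessian_diag_vec and the layer are hypothetical). Each basis vector e_k goes through a full hessian-vector product of which only entry k is kept, so this is N full products, roughly the cost of the whole hessian:

```python
import torch
from torch.func import vmap, jvp, jacrev

torch.manual_seed(0)
layer = torch.nn.Linear(3, 5).requires_grad_(False)

def f(x):
    return torch.sigmoid(layer(x))

def hvp_vecfwd(f, x, v):
    return jvp(jacrev(f), (x,), (v,))[1]

def hessian_diag_vec(f, x):
    # Hypothetical helper: one hvp per basis vector e_k, batched with vmap.
    # cols[k, i, j] = H_i[j, k]; the diagonal over dims (0, 2) gives H_i[k, k].
    eye = torch.eye(x.shape[-1], dtype=x.dtype)
    cols = vmap(lambda e: hvp_vecfwd(f, x, e))(eye)   # [N, M, N]
    return torch.diagonal(cols, dim1=0, dim2=2)        # [M, N]

x = torch.tensor([1.0, 2.0, 3.0])
d = hessian_diag_vec(f, x)
```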

Just go ahead and let autograd compute the full [N, N] hessian, and then
discard all but the diagonal elements.
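For the batched vector-valued case in the original code, that recipe might look like the following sketch (assuming PyTorch 2.x; the layer and seed are illustrative). vmap(hessian(f)) gives [B, M, N, N], and torch.diagonal over the last two dims leaves [B, M, N]:

```python
import torch
from torch.func import vmap, hessian

torch.manual_seed(0)
layer = torch.nn.Linear(3, 5).requires_grad_(False)

def f_batch_vector(x):
    return torch.sigmoid(layer(x))

x_batch = torch.tensor([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0],
                        [7.0, 8.0, 9.0]])

H = vmap(hessian(f_batch_vector))(x_batch)      # [B, M, N, N] = [3, 5, 3, 3]
H_diag = torch.diagonal(H, dim1=-2, dim2=-1)    # [B, M, N] = [3, 5, 3]
```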

You can’t use a hessian-vector product to obtain the diagonal of the hessian.

Consider some [N, N] matrix H. There is no single vector v (of shape [N])
for which H @ v gives you the diagonal of H for every H. That’s just
not how matrix-vector multiplication works.
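Written out, with H a symmetric [N, N] matrix, v of shape [N], and e_k the k-th standard basis vector:

```latex
(Hv)_k = \sum_{j=1}^{N} H_{kj} v_j = H_{kk} v_k + \sum_{j \neq k} H_{kj} v_j

% Requiring (Hv)_k = H_{kk} for every symmetric H:
%   H = e_k e_k^T                          gives  v_k = 1
%   H = e_k e_j^T + e_j e_k^T  (j \neq k)  gives  v_j = 0
% Both hold for every k, so each v_j must be 1 and 0 at once: impossible for N >= 2.

v = \mathbf{1} \;\Rightarrow\; (H\mathbf{1})_k = \sum_{j=1}^{N} H_{kj} \quad \text{(row sums of } H\text{)}
```

So torch.ones_like(x) yields row sums, which equal the diagonal only when the off-diagonal entries in each row sum to zero, for example when H is diagonal, as it is for sum(x*x).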

Leaving aside the issue of obtaining the diagonal of the hessian, are you
asking more about how to do this for a vector-valued function, for a batch
of independent inputs, or for both?

Best.

K. Frank