My goal is to compute the Hessian with respect to the input of the function / neural net, not with respect to its parameters.
Suppose you have a vector-valued function in PyTorch, f: R^N → R^M.
Computing its full Hessian returns an [M, N, N] tensor, and its diagonal (over the last two dimensions) is an [M, N] matrix.
Computing the whole Hessian and then extracting the diagonal is too expensive for large functions like neural nets.
The PyTorch docs already provide a technique for something similar in the scalar-valued case (f: R^N → R, whose Hessian diagonal is an [N] vector), using a Hessian-vector product (hvp):
def hvp(f, x, v):
    return jvp(grad(f), (x,), (v,))[1]
Using it directly on a vector-valued function, however, raises an error, since grad requires a scalar output.
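For the scalar case it behaves as I expect (a minimal check; for f(x) = sum(x*x) the Hessian is 2I, so contracting it with v = ones gives the diagonal here):

import torch
from torch.func import jvp, grad

def hvp(f, x, v):
    return jvp(grad(f), (x,), (v,))[1]

x = torch.tensor([1.0, 2.0, 3.0])
print(hvp(lambda t: torch.sum(t * t), x, torch.ones_like(x)))  # tensor([2., 2., 2.])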
I’ve tried writing my own implementation, hvp_vecfwd:
# My attempt
def hvp_vecfwd(f, x, v):
    return jvp(jacrev(f), (x,), (v,))[1]
It does produce the correct shape, but with the tangents
torch.ones_like(x_single)
the values inside are not all correct, even on simple vector functions (f_single_vector in the code below). Is the problem my choice of tangents, my implementation, or both?
Any help or insights on this would be great!
import torch
from torch.func import vmap, jvp, grad, jacrev, hessian
x_single = torch.tensor([1.0, 2.0, 3.0], requires_grad=False)
x_batch = torch.tensor([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0],
                        [7.0, 8.0, 9.0]], requires_grad=False)
layer = torch.nn.Linear(3, 5).requires_grad_(False)
# For Scalar Functions
# Provided by PyTorch Docs
def hvp(f, x, v):
    return jvp(grad(f), (x,), (v,))[1]
# For Vector Functions
# My attempt
def hvp_vecfwd(f, x, v):
    return jvp(jacrev(f), (x,), (v,))[1]
# Simple functions for debugging
# Batched f: [B, R^N] -> [B, R]
def f_batch_scalar(x):
    return torch.sum(x * x, dim=-1)
# f: R^N -> R^M
def f_single_vector(x):
    return torch.sigmoid(layer(x))
# Batched f: [B, R^N] -> [B, R^M]
def f_batch_vector(x):
    return torch.sigmoid(layer(x))
print("Batched Full Hessian of Scalar Function")
# Batched f: [B, R^N] -> [B, R], d^2f/dx^2 = [B, N, N]
print(vmap(hessian(f_batch_scalar))(x_batch))
print("Batched Diagonal Hessian of Scalar Function")
# Batched f: [B, R^N] -> [B, R], diag[H] = [B, N]
print(vmap(hvp, in_dims=(None, 0, 0))(f_batch_scalar, x_batch, torch.ones_like(x_batch)))
print("Non-Batched Full Hessian of Vector Function")
# f: R^N -> R^M, d^2f/dx^2 = [M, N, N]
print(hessian(f_single_vector)(x_single))
print("Non-Batched Diagonal Hessian of Vector Function")
# f: R^N -> R^M, diag[H] = [M, N]
# WRONG
print(hvp_vecfwd(f_single_vector, x_single, torch.ones_like(x_single)))
print("Batched Full Hessian of Vector Function")
# Batched f: [B, R^N] -> [B, R^M], d^2f/dx^2 = [B, M, N, N]
print(vmap(hessian(f_batch_vector))(x_batch))
print("Batched Diagonal Hessian of Vector Function")
# Batched f: [B, R^N] -> [B, R^M], diag[H] = [B, M, N]
# WRONG
print(vmap(hvp_vecfwd, in_dims=(None, 0, 0))(f_batch_vector, x_batch, torch.ones_like(x_batch)))
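For reference, here is the sanity check I'm using on the non-batched vector case (a minimal sketch reusing the definitions above): since jvp(jacrev(f), (x,), (v,)) differentiates the Jacobian along v, I'd expect hvp_vecfwd to return the [M, N, N] Hessian contracted with v over its last axis, which I compare against the true diagonal:

# Compare hvp_vecfwd against the full Hessian contracted with v,
# and against the true diagonal of the Hessian.
H = hessian(f_single_vector)(x_single)        # [M, N, N]
v = torch.ones_like(x_single)
print(torch.allclose(hvp_vecfwd(f_single_vector, x_single, v),
                     torch.einsum('mnk,k->mn', H, v)))
print(torch.diagonal(H, dim1=-2, dim2=-1))    # diag[H] = [M, N]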