Neumann series approximation of an inverse-Hessian-vector product diverges

I’m trying to implement a hyperparameter optimization method outlined in this paper. The method depends on approximating the inverse of a Hessian matrix using a Neumann series, which is where I’m running into problems.

Essentially, after updating the model parameters, the main part of the problem is computing the product of the gradient of the validation loss w.r.t. the model parameters with the inverse of the Hessian of the training loss w.r.t. the model parameters. The paper approximates this inverse with a truncated Neumann series. Rather than instantiating large second-derivative matrices, we rely on Hessian-vector products that we calculate iteratively to get each term in the series:

The idea is that we can calculate each new term as v - grad(dlt_dw, w, grad_outputs=v), i.e. (I - H)v, and update v with the result.
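
For reference, here is the scheme written out. This is my reconstruction from the description above (the standard Neumann expansion of a matrix inverse), not the paper's exact notation:

$$
H^{-1} g \approx \sum_{j=0}^{K} (I - H)^j g, \qquad v_0 = g, \quad v_{j+1} = v_j - H v_j,
$$

so the approximation is the running sum $v_0 + v_1 + \dots + v_K$, each step costing one Hessian-vector product. The series only converges when the spectral radius of $I - H$ is less than 1.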

I have the following (contrived) example that shows what I’m currently running. The model is a neural network with a single hyperparameter, the weight-decay parameter lambda. You can see I’m trying to calculate the first ten terms of the Neumann series:

import torch
import torch.nn as nn
from torch.autograd import grad
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
torch.manual_seed(123)

# Create some dummy data for training
n = 1000
x_train = torch.randn(n, 1)
y_train = 2 * x_train + 1 + torch.randn(n, 1) * 0.1

# Create some dummy data for validation
x_val = torch.randn(n, 1)
y_val = 2 * x_val + 1 + torch.randn(n, 1) * 0.1

x_train = torch.as_tensor(x_train).float()
y_train = torch.as_tensor(y_train).float().view(n, -1)
train_dat = TensorDataset(x_train, y_train)

x_val = torch.as_tensor(x_val).float()
y_val = torch.as_tensor(y_val).float().view(n, -1)

batch_size = 32
train_loader = DataLoader(dataset=train_dat,
                          shuffle=True,
                          batch_size=batch_size)

model = nn.Sequential(nn.Linear(1, 8),
                      nn.ReLU(),
                      nn.Linear(8, 4),
                      nn.ReLU(),
                      nn.Linear(4, 1))
loss_fn = nn.MSELoss()
lambda_penalty = torch.nn.Parameter(torch.tensor(0.1, requires_grad=True))
lr = 0.01

# Training loop
num_epochs = 1
for epoch in range(num_epochs):
    # Recreate the optimizer each epoch because, in the full code, the weight updates depend on the updated lambda
    optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=lambda_penalty.item())
    for x_batch, y_batch in train_loader:
        model.train()
        optimizer.zero_grad()
        y_pred = model(x_batch)
        loss = loss_fn(y_pred, y_batch)
        loss.backward()
        optimizer.step()

    # Update lambda
    optimizer.zero_grad()
    # (lambda / 2) * ||theta||^2 penalty, matching SGD's weight_decay gradient
    regularization_term = 1 / 2 * lambda_penalty * torch.norm(
        torch.cat([param.view(-1) for param in model.parameters()]), p=2) ** 2
    y_train_pred = model(x_train)
    y_val_pred = model(x_val)
    loss_train = loss_fn(y_train_pred, y_train) + regularization_term
    loss_val = loss_fn(y_val_pred, y_val)

    dlv_dw = grad(loss_val, model.parameters(), create_graph=True)
    dlt_dw = grad(loss_train, model.parameters(), create_graph=True)

    # Flatten the per-parameter gradients into single vectors
    dlv_dw = torch.cat([g.view(-1) for g in dlv_dw])
    dlt_dw = torch.cat([g.view(-1) for g in dlt_dw])

    with torch.no_grad():
        v = dlv_dw.detach().clone()
        p = dlv_dw.detach().clone()
        for i in range(10):
            # Hessian-vector product: tmp_v = H @ v, where H is the Hessian of
            # the training loss w.r.t. the model parameters (H is symmetric)
            tmp_v = grad(dlt_dw, model.parameters(), grad_outputs=v, retain_graph=True)
            tmp_v = torch.cat([g.view(-1) for g in tmp_v])
            v = v - tmp_v  # v <- (I - H) v, the next term in the series
            p = p + v      # running sum approximating H^{-1} dlv_dw
            print(v[:5])

        # REST OF LAMBDA UPDATE CODE NOT SHOWN
        # ....

The problem is that the terms of the series are exploding, so the truncated sum diverges. Here, for example, are the first five elements of the v vector on each of the 10 iterations, from the print line:

tensor([ 0.1156, -0.1112, -0.1701,  0.1533,  0.2598])
tensor([ 0.3945, -0.2782, -0.4173,  0.3951,  0.6880])
tensor([ 1.1269, -0.6591, -0.9736,  1.0918,  1.9326])
tensor([ 3.2012, -1.5501, -2.2807,  2.9116,  5.2814])
tensor([ 8.9077, -3.6360, -5.3440,  7.9048, 14.5893])
tensor([ 24.7851,  -8.5381, -12.5502,  21.3749,  40.1064])
tensor([ 68.6840, -20.1017, -29.5685,  58.0873, 110.4438])
tensor([190.1629, -47.4884, -69.9186, 157.9792, 303.8438])
tensor([ 525.7173, -112.6450, -166.0421,  430.5572,  836.0579])
tensor([1452.1401, -268.4406, -396.2061, 1174.8359, 2299.8997])

Now, it’s possible the Neumann series doesn’t actually converge here, but I have a feeling it’s more likely that I’m coding this incorrectly: either calculating the second derivatives wrong or letting gradients accumulate in some way. Does anyone see why this code might be doing this?

In this case, I believe it’s just a function of this particular example: the Hessian of the training loss is such that I - H has spectral radius greater than 1, so the series genuinely diverges. A toy example where the spectral radius condition is satisfied (the largest absolute eigenvalue of I - H is less than 1) shows the series converging:

import torch
from torch.func import hessian

def f(x):
    return x.pow(3).sum()

# At x = [0.05, 0.05] the Hessian is diag(6 * x) = diag(0.3, 0.3), so
# I - H = diag(0.7, 0.7) has spectral radius 0.7 < 1 and the series converges
x = torch.tensor([0.05, 0.05], requires_grad=True)
v = torch.ones(2)

H = hessian(f)(x)
print(H)
H_inv = torch.tensor([[1 / 0.3, 0], [0, 1 / 0.3]])  # exact inverse, for comparison

g = torch.autograd.grad(f(x), x, create_graph=True)

with torch.no_grad():
    vi = v.clone()
    p = v.clone()
    for i in range(50):
        # Hessian-vector product H @ vi via double backward
        tmp_v = torch.autograd.grad(g, x, grad_outputs=vi, retain_graph=True)[0]
        vi = vi - tmp_v  # vi <- (I - H) vi
        p = p + vi       # running sum of the series terms

print(torch.allclose(v @ H_inv, p))  # True
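
For completeness: the usual way to make this work for a general Hessian is to expand around a rescaled matrix, since $\alpha \sum_{j=0}^{K} (I - \alpha H)^j \to H^{-1}$ whenever the eigenvalues of $\alpha H$ lie in $(0, 2)$. If I recall the paper correctly, this scaling is tied to the learning rate there. Below is a minimal sketch of the rescaled iteration; the function f and the value alpha = 0.1 are illustrative choices of mine, hand-picked so the condition holds:

import torch

def f(x):
    return (x ** 4).sum()  # Hessian is diag(12 * x^2), with eigenvalues > 2 here

x = torch.tensor([0.7, 0.9], requires_grad=True)
v = torch.ones(2)
alpha = 0.1  # hand-picked so all eigenvalues of alpha * H fall in (0, 2)

g = torch.autograd.grad(f(x), x, create_graph=True)

with torch.no_grad():
    vi = v.clone()
    p = v.clone()
    for i in range(200):
        # Hessian-vector product H @ vi via double backward
        Hv = torch.autograd.grad(g, x, grad_outputs=vi, retain_graph=True)[0]
        vi = vi - alpha * Hv  # vi <- (I - alpha * H) vi
        p = p + vi            # running sum of the series terms
    p = alpha * p             # alpha * sum_j (I - alpha * H)^j v ~= H^{-1} v

# exact inverse Hessian for comparison: H = diag(12 * x^2)
H_inv = torch.diag(1.0 / (12 * x.detach() ** 2))
print(torch.allclose(p, v @ H_inv, atol=1e-4))  # True

The same scaling applies verbatim in my first example (multiply tmp_v by alpha inside the loop and rescale the final sum by alpha), provided the training-loss Hessian is positive definite and alpha is small enough relative to its largest eigenvalue.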