I’m trying to implement a hyperparameter optimization method outlined in this paper. The hyperparameter optimization depends on approximating the inverse of a Hessian matrix using a Neumann series, which is where I’m running into problems.
Essentially, after updating the model parameters, the main part of the problem is computing the product of the gradient of the validation loss with respect to the model parameters and the inverse of the Hessian (the second derivative) of the training loss with respect to the model parameters. They approximate this inverse with a truncated Neumann series (I didn't see how to typeset math on here, so I've just included images). Rather than instantiating large second-derivative matrices, we rely on vector-Hessian products, computed iteratively, to get each term in the Neumann series.
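Written out, assuming the standard unscaled Neumann expansion (the symbols $\mathcal{L}_V$, $\mathcal{L}_T$, $w$, $H$ and the truncation length $K$ are introduced here for illustration and are not taken from the paper), the quantity being approximated is:

```latex
% v: gradient of the validation loss; H: Hessian of the training loss (both w.r.t. w)
v = \frac{\partial \mathcal{L}_V}{\partial w}, \qquad
H = \frac{\partial^2 \mathcal{L}_T}{\partial w \, \partial w^\top}

% Neumann series, valid only when the spectral radius \rho(I - H) < 1,
% i.e. (for symmetric H) when every eigenvalue of H lies in (0, 2)
H^{-1} = \sum_{j=0}^{\infty} (I - H)^j
\quad\Longrightarrow\quad
v^\top H^{-1} \approx \sum_{j=0}^{K} v_j^\top,
\qquad v_0 = v, \qquad
v_{j+1}^\top = v_j^\top (I - H) = v_j^\top - v_j^\top H
```

Each $v_j^\top H$ is a vector-Hessian product, so $H$ itself never has to be formed.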
The idea is that we can compute each term as `v - grad(dlt_dw, w, grad_outputs=v)` and then update `v` with it.
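To see this update rule in isolation, here is a minimal sketch on a toy quadratic training loss $\tfrac12 w^\top A w$ whose Hessian $A$ is known. The matrix `A`, the random stand-in `v0` for the validation gradient, and the 200-term truncation are all hypothetical choices; the eigenvalues of `A` are deliberately placed inside (0, 2) so that the unscaled series converges.

```python
import torch
from torch.autograd import grad

torch.manual_seed(0)

# Toy training loss L_T(w) = 0.5 * w^T A w with a known symmetric Hessian A.
# Eigenvalues are chosen in (0, 2) so that the unscaled Neumann series converges.
Q, _ = torch.linalg.qr(torch.randn(4, 4, dtype=torch.float64))
A = Q @ torch.diag(torch.tensor([0.3, 0.7, 1.2, 1.6], dtype=torch.float64)) @ Q.T

w = torch.randn(4, dtype=torch.float64, requires_grad=True)
loss_t = 0.5 * w @ A @ w
dlt_dw = grad(loss_t, w, create_graph=True)[0]  # keep the graph for double backward

v0 = torch.randn(4, dtype=torch.float64)  # stand-in for dL_V/dw

v = v0.clone()
p = v0.clone()  # running sum of the Neumann terms
for _ in range(200):
    # grad_outputs=v yields the vector-Hessian product v^T H without forming H
    vH = grad(dlt_dw, w, grad_outputs=v, retain_graph=True)[0]
    v = v - vH  # v_{j+1} = v_j (I - H)
    p = p + v

# H is symmetric, so v^T H^{-1} equals H^{-1} v
exact = torch.linalg.solve(A, v0)
print(torch.allclose(p, exact, atol=1e-8))
```

Because the largest eigenvalue of $I - A$ here has magnitude 0.7, the truncated sum matches the exact solve; the same loop body is what the question's code runs on the real network.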
I have the following (contrived) example that shows what I'm currently running: a neural network with a single hyperparameter, the weight decay parameter lambda. As you can see, I'm trying to calculate the first ten terms in the Neumann series:
```python
import torch
import torch.nn as nn
from torch.autograd import grad
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim

torch.manual_seed(123)

# Create some dummy data for training
n = 1000
x_train = torch.randn(n, 1)
y_train = 2 * x_train + 1 + torch.randn(n, 1) * 0.1

# Create some dummy data for validation
x_val = torch.randn(n, 1)
y_val = 2 * x_val + 1 + torch.randn(n, 1) * 0.1

x_train = torch.as_tensor(x_train).float()
y_train = torch.as_tensor(y_train).float().view(n, -1)
train_dat = TensorDataset(x_train, y_train)
x_val = torch.as_tensor(x_val).float()
y_val = torch.as_tensor(y_val).float().view(n, -1)

batch_size = 32
train_loader = DataLoader(dataset=train_dat,
                          shuffle=True,
                          batch_size=batch_size)

model = nn.Sequential(nn.Linear(1, 8),
                      nn.ReLU(),
                      nn.Linear(8, 4),
                      nn.ReLU(),
                      nn.Linear(4, 1))
loss_fn = nn.MSELoss()
lambda_penalty = torch.nn.Parameter(torch.tensor(0.1, requires_grad=True))
lr = 0.01

# Training loop
num_epochs = 1
for epoch in range(num_epochs):
    # Putting this here because in the full code the weight updates will occur conditional on the updated lambda
    optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=lambda_penalty.item())
    for x_batch, y_batch in train_loader:
        x_batch = x_batch
        y_batch = y_batch
        model.train()
        optimizer.zero_grad()
        y_pred = model(x_batch)
        loss = loss_fn(y_pred, y_batch)
        loss.backward()
        optimizer.step()

    # Update lambda
    optimizer.zero_grad()
    regularization_term = 1 / 2 * lambda_penalty * torch.norm(torch.cat([param.view(-1) for param in model.parameters()]), p=2) ** 2
    y_train_pred = model(x_train)
    y_val_pred = model(x_val)
    loss_train = loss_fn(y_train_pred, y_train) + regularization_term
    loss_val = loss_fn(y_val_pred, y_val)
    dlv_dw = grad(loss_val, model.parameters(), create_graph=True)
    dlt_dw = grad(loss_train, model.parameters(), create_graph=True)
    # Flatten the per-parameter gradients into single vectors
    dlv_dw = torch.cat([g.view(-1) for g in dlv_dw])
    dlt_dw = torch.cat([g.view(-1) for g in dlt_dw])
    with torch.no_grad():
        v = dlv_dw.detach().clone()
        p = dlv_dw.detach().clone()
        for i in range(10):
            # Vector-Hessian product v^T H via double backward through dlt_dw
            tmp_v = grad(dlt_dw, model.parameters(), grad_outputs=v, retain_graph=True)
            tmp_v = torch.cat([g.view(-1) for g in tmp_v])
            v = v - tmp_v  # next Neumann term: v (I - H)
            print(v[:5])  # first five entries of the current term
    # REST OF LAMBDA UPDATE CODE NOT SHOWN
    # ....
```
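One way to rule out a bookkeeping error in the flatten-then-differentiate pattern above is to compare it with an explicit Hessian on a small float64 copy of the same architecture. This is a minimal sketch assuming PyTorch 2.x (for `torch.func.functional_call`); the 64-point dataset, the fixed `lam = 0.1`, and the helper `train_loss_from_flat` are hypothetical stand-ins, not part of the original code.

```python
import torch
import torch.nn as nn
from torch.autograd import grad
from torch.autograd.functional import hessian
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 4),
                      nn.ReLU(), nn.Linear(4, 1)).double()
x = torch.randn(64, 1, dtype=torch.float64)
y = 2 * x + 1
loss_fn = nn.MSELoss()
lam = 0.1  # hypothetical fixed weight-decay value for this check

params = list(model.parameters())
names = [name for name, _ in model.named_parameters()]
shapes = [q.shape for q in params]
sizes = [q.numel() for q in params]

def train_loss_from_flat(flat):
    # Hypothetical helper: rebuild a name -> tensor dict from a flat vector
    # and evaluate the same regularized training loss
    chunks = torch.split(flat, sizes)
    new_params = {n: c.reshape(s) for n, c, s in zip(names, chunks, shapes)}
    pred = functional_call(model, new_params, (x,))
    return loss_fn(pred, y) + 0.5 * lam * flat.pow(2).sum()

# 1) Pattern from the question: flatten the gradient, then differentiate again with grad_outputs=v
loss = loss_fn(model(x), y) + 0.5 * lam * sum(q.pow(2).sum() for q in params)
dlt_dw = torch.cat([g.reshape(-1) for g in grad(loss, params, create_graph=True)])
v = torch.randn_like(dlt_dw)
vH = torch.cat([g.reshape(-1) for g in grad(dlt_dw, params, grad_outputs=v)])

# 2) Explicit Hessian of the same loss as a function of the flat parameter vector
flat = torch.cat([q.detach().reshape(-1) for q in params])
H = hessian(train_loss_from_flat, flat)

print(torch.allclose(vH, v @ H, atol=1e-8))
evals = torch.linalg.eigvalsh(H)
print(evals.min().item(), evals.max().item())  # extreme eigenvalues of H
```

If the first line prints `True`, the double-backward products agree with the explicit Hessian, and the extreme eigenvalues show whether this small model's Hessian stays inside (0, 2), the range in which the unscaled Neumann series converges.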
The problem is that the terms of the series are exploding, so the series diverges. For example, here are the first five elements of `v` at each of the 10 iterations, from the `print` line:
```
tensor([ 0.1156, -0.1112, -0.1701, 0.1533, 0.2598])
tensor([ 0.3945, -0.2782, -0.4173, 0.3951, 0.6880])
tensor([ 1.1269, -0.6591, -0.9736, 1.0918, 1.9326])
tensor([ 3.2012, -1.5501, -2.2807, 2.9116, 5.2814])
tensor([ 8.9077, -3.6360, -5.3440, 7.9048, 14.5893])
tensor([ 24.7851, -8.5381, -12.5502, 21.3749, 40.1064])
tensor([ 68.6840, -20.1017, -29.5685, 58.0873, 110.4438])
tensor([190.1629, -47.4884, -69.9186, 157.9792, 303.8438])
tensor([ 525.7173, -112.6450, -166.0421, 430.5572, 836.0579])
tensor([1452.1401, -268.4406, -396.2061, 1174.8359, 2299.8997])
```
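A quick arithmetic check on the printed numbers: the ratio between successive values of the first component (copied from the output above) settles near 2.76.

```python
# First component of v from the ten printed iterations
first = [0.1156, 0.3945, 1.1269, 3.2012, 8.9077,
         24.7851, 68.6840, 190.1629, 525.7173, 1452.1401]

# Ratio of successive iterates
ratios = [b / a for a, b in zip(first, first[1:])]
for r in ratios:
    print(f"{r:.3f}")
```

Writing $H$ for the Hessian of the training loss, each step multiplies `v` by $(I - H)$, so the loop behaves like power iteration and this ratio typically tends toward the largest-magnitude eigenvalue $\mu$ of $I - H$. The absence of sign flips suggests $\mu$ is positive, which would put an eigenvalue of $H$ near $1 - 2.76 \approx -1.76$ (negative curvature). This is only an estimate from one printed component, not a computed eigenvalue.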
Now, it's possible that the Neumann series genuinely doesn't converge here, but I suspect it's more likely that I've coded this incorrectly, either by computing the second derivatives wrongly or by letting gradients accumulate somehow. Does anyone see why this code might behave this way?
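On the first possibility: for a symmetric $H$, the series $\sum_j (I - H)^j$ converges only if every eigenvalue of $H$ lies in (0, 2). A common way to handle large positive eigenvalues is to scale by a step size $\alpha$ and use $H^{-1} = \alpha \sum_j (I - \alpha H)^j$, which converges when $0 < \alpha\lambda < 2$ for every eigenvalue $\lambda$; no $\alpha > 0$ helps with a negative eigenvalue, since then $|1 - \alpha\lambda| > 1$. The sketch below checks this on explicit 2×2 diagonal matrices. The helper `neumann_inverse_times` and the matrices are hypothetical, and the scaled form is a standard variant rather than necessarily the paper's exact formula.

```python
import torch

def neumann_inverse_times(v0, H, alpha=1.0, num_terms=100):
    """Hypothetical helper: approximate H^{-1} v0 as alpha * sum_{j<num_terms} (I - alpha*H)^j v0.

    Assumes H is symmetric, so (I - alpha*H) v equals v^T (I - alpha*H).
    """
    v = v0.clone()
    p = v0.clone()
    for _ in range(num_terms - 1):
        v = v - alpha * (H @ v)  # one more factor of (I - alpha*H)
        p = p + v
    return alpha * p

v0 = torch.tensor([1.0, 1.0], dtype=torch.float64)

# Case 1: eigenvalue 4 is outside (0, 2), so the unscaled series diverges (|1 - 4| = 3 > 1)
H_big = torch.diag(torch.tensor([4.0, 0.5], dtype=torch.float64))
unscaled = neumann_inverse_times(v0, H_big, alpha=1.0, num_terms=30)

# alpha = 0.25 maps the eigenvalues of alpha*H to {1.0, 0.125}, inside (0, 2)
scaled = neumann_inverse_times(v0, H_big, alpha=0.25, num_terms=300)
exact = torch.linalg.solve(H_big, v0)

# Case 2: a negative eigenvalue gives |1 - alpha*lam| = 1 + alpha*|lam| > 1 for any alpha > 0
H_neg = torch.diag(torch.tensor([-1.0, 0.5], dtype=torch.float64))
neg = neumann_inverse_times(v0, H_neg, alpha=0.1, num_terms=300)

print(unscaled.abs().max().item())  # huge: diverged
print(torch.allclose(scaled, exact))
print(neg.abs().max().item())       # huge: diverged despite scaling
```

Computing the extreme eigenvalues of the real training-loss Hessian (feasible for a network this small) would show which of these cases applies.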