import torch
import torch.nn as nn
import time
from congrad.torch import cg_batch
from torch.func import grad, vmap, vjp, functional_call
# Run on CUDA when a GPU is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Define a larger neural network with 5 layers and 512 hidden units
class NeuralNet(nn.Module):
def __init__(self):
super(NeuralNet, self).__init__()
self.fc1 = nn.Linear(24, 512) # Input size is the sum of observation and action dimensions
self.fc2 = nn.Linear(512, 512)
self.fc3 = nn.Linear(512, 512)
self.fc4 = nn.Linear(512, 512)
self.fc5 = nn.Linear(512, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = torch.relu(self.fc3(x))
x = torch.relu(self.fc4(x))
return self.fc5(x)
# Initialize the neural network and move it to the GPU
net = NeuralNet().to(device)
params = dict(net.named_parameters())
# Mock data, moved to GPU
torch.manual_seed(0) # For reproducibility
batch_size = 1024
x = torch.rand(batch_size, 24, device=device) # 1024 observations, 17 features each
# Define the necessary functions
def fcall(params, x):
return functional_call(net, params, x.unsqueeze(0)).squeeze()
def compute_gradient_and_its_copy(params, x):
gradient = grad(fcall, argnums=1)(params, x)
return gradient, gradient
def compute_vjp_fn_and_gradient(params, x):
_, vjp_fn, gradient = vjp(lambda x: compute_gradient_and_its_copy(params, x), x, has_aux=True)
return vjp_fn, gradient
with torch.no_grad():
vjp_fn, gradient = vmap(compute_vjp_fn_and_gradient, in_dims=(None, 0))(params, x)
solution, solve_info = cg_batch(lambda x: vjp_fn(x)[0], gradient)
I am trying to compute a batch of vector-Jacobian products (here, Hessian-vector products, since the function being differentiated is the gradient) and use them to solve the linear system $H x = g$, where $H$ is the Hessian and $g$ the gradient, with the conjugate gradient method. This approach avoids materializing the Hessian matrix and should, hopefully, be faster. However,
`vmap` does not support returning a batch of functions. Is there any workaround for this?
Thanks in advance for your time!