How to compute a batched vector-Jacobian product?


import torch
import torch.nn as nn
import time
from congrad.torch import cg_batch
from torch.func import grad, vmap, vjp, functional_call

# Check if GPU is available and set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define a larger neural network with 5 layers and 512 hidden units
class NeuralNet(nn.Module):
    def __init__(self):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(24, 512)  # Input size is the sum of observation and action dimensions
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 512)
        self.fc4 = nn.Linear(512, 512)
        self.fc5 = nn.Linear(512, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = torch.relu(self.fc4(x))
        return self.fc5(x)

# Initialize the neural network and move it to the GPU
net = NeuralNet().to(device)
params = dict(net.named_parameters())

# Mock data, moved to GPU
torch.manual_seed(0)  # For reproducibility
batch_size = 1024
x = torch.rand(batch_size, 24, device=device)  # 1024 observations, 24 features each

# Define the necessary functions
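# Functional forward pass on a single sample; squeeze the (1, 1) output
# to a scalar so that grad can differentiate it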
def fcall(params, x):
    return functional_call(net, params, x.unsqueeze(0)).squeeze()

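# Return the gradient twice: the first copy is the output that vjp will
# differentiate (its Jacobian w.r.t. x is the Hessian), the second comes
# back untouched as the aux value and serves as the right-hand side for CG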
def compute_gradient_and_its_copy(params, x):
    gradient = grad(fcall, argnums=1)(params, x)
    return gradient, gradient

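# Take the vjp of the gradient function: applying vjp_fn to a vector v
# yields the Hessian-vector product Hess @ v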
def compute_vjp_fn_and_gradient(params, x):
    _, vjp_fn, gradient = vjp(lambda x: compute_gradient_and_its_copy(params, x), x, has_aux=True)
    return vjp_fn, gradient

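# This is where it breaks: vjp_fn is a function, and vmap cannot return
# a batch of functions from the mapped callable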
with torch.no_grad():
    vjp_fn, gradient = vmap(compute_vjp_fn_and_gradient, in_dims=(None, 0))(params, x)
    solution, solve_info = cg_batch(lambda x: vjp_fn(x)[0], gradient)

I am trying to compute a batch of vector-Jacobian products and use them to solve the per-sample linear system Hess @ x = grad with the conjugate gradient method. This approach avoids ever materializing the Hessian matrix and should, hopefully, be faster. However, vmap does not support returning a batch of functions, so compute_vjp_fn_and_gradient fails when it tries to return vjp_fn. I am wondering if there is any workaround for this.
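
For reference, here is the per-sample computation I am trying to batch, written as a standalone sketch (hvp_single is just a name I made up for illustration): taking the vjp of the gradient function yields a Hessian-vector product without ever forming the Hessian.

def hvp_single(params, x_i, v_i):
    # Gradient of the scalar network output w.r.t. one input sample
    grad_fn = lambda x_: grad(fcall, argnums=1)(params, x_)
    # vjp of the gradient function: vjp_fn(v_i) computes Hess @ v_i
    # (the Hessian is symmetric, so left- and right-products coincide)
    _, vjp_fn = vjp(grad_fn, x_i)
    return vjp_fn(v_i)[0]

I could vmap this as vmap(hvp_single, in_dims=(None, 0, 0))(params, x, v) and call it once per CG iteration, but that redoes the vjp setup on every iteration instead of reusing vjp_fn, which is what I was hoping to gain by returning the function from vmap.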

Thanks in advance for your time!