Set parameters on a model without breaking autograd

I am trying to do this:

def closure(*params):
    # overwrite the model's parameters in place with the new values
    for p, new_p in zip(optimizer_params, params):
        p.set_(new_p)  # this errors when used under torch.func.jvp
    preds = model(inputs)
    loss = loss_fn(preds, targets)
    return loss

torch.func.jvp(closure, params, vec)  # params and vec are tuples of tensors

One way this can be done is via torch.func.functional_call, which lets me write:

def closure(*params):
    # map the flat tuple of new tensors back to the names of the trainable parameters
    trainable = [(k, v) for k, v in model.named_parameters() if v.requires_grad]
    params_dict = {k: p for (k, _), p in zip(trainable, params)}
    preds = torch.func.functional_call(model, params_dict, (inputs,))
    loss = loss_fn(preds, targets)
    return loss

torch.func.jvp(closure, params_tuple, vec)

However, it would be more convenient if I could somehow use the first way, because the second one requires a torch.nn.Module instead of an arbitrary list of tensors, and I have to rewrite my closure and mess with named_parameters.
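For completeness, the name bookkeeping can at least be confined to a small adapter. This is only a sketch with a helper name (make_functional_closure) of my own invention, not an existing API, and it still needs access to the nn.Module:

import torch

def make_functional_closure(model, inputs, targets, loss_fn):
    # collect the names of the trainable parameters once, up front
    names = [n for n, p in model.named_parameters() if p.requires_grad]

    def closure(*params):
        # rebuild the name -> tensor mapping that functional_call expects
        params_dict = dict(zip(names, params))
        preds = torch.func.functional_call(model, params_dict, (inputs,))
        return loss_fn(preds, targets)

    return closure

# usage: torch.func.jvp(make_functional_closure(model, inputs, targets, loss_fn), params_tuple, vec)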

I found this method in pytorch-minimize/pytorch_minimize/optim.py (gngdb/pytorch-minimize on GitHub):

def f(x):
    _params = self.unravel_unpack(x)
    # monkey patch substitute variables
    named_params = list(model.named_parameters())
    for _p, (n, _) in zip(_params, named_params):
        rdelattr(model, n)
        rsetattr(model, n, _p)
    return closure.loss(model)

def numpyify(x):
    if x.device != torch.device('cpu'):
        x = x.cpu()
    #return x.numpy().astype(np.float64)
    return self.floatX(x.numpy())

return numpyify(torch.autograd.functional.hessian(f, x))

I am afraid this will break references that other objects hold to those parameters, and it still requires a torch.nn.Module. But maybe something like this is possible without modifying attributes?
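To make that worry concrete, here is a minimal sketch of the kind of breakage I mean (my own illustration, with plain del/setattr standing in for the repo's rdelattr/rsetattr helpers): an optimizer keeps references to the original Parameter objects, so replacing module attributes leaves it updating tensors the model no longer uses.

import torch, torch.nn as nn

m = nn.Linear(2, 2)
opt = torch.optim.SGD(m.parameters(), lr=0.1)

old_weight = m.weight
del m.weight                                # analogous to rdelattr(model, n)
m.weight = nn.Parameter(torch.zeros(2, 2))  # analogous to rsetattr(model, n, _p)

print(opt.param_groups[0]["params"][0] is m.weight)    # False: the optimizer no longer sees the live weight
print(opt.param_groups[0]["params"][0] is old_weight)  # True: it still points at the old, now-orphaned parameter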

Maybe something like torch.utils.swap_tensors (PyTorch 2.5 documentation) would be helpful?
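For context, a tiny sketch of what torch.utils.swap_tensors does: the two tensor objects exchange their contents in place, so any other references to them stay valid.

import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([3.0, 4.0])
alias = a  # some other reference to the same tensor object

torch.utils.swap_tensors(a, b)

print(a)      # tensor([3., 4.])
print(b)      # tensor([1., 2.])
print(alias)  # tensor([3., 4.])  -- the reference is preserved, only the contents are swapped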

Thank you, this actually worked! Here is a minimal example of how to make it work:

import torch, torch.nn as nn

model = nn.Sequential(nn.Linear(2, 3), nn.ReLU(), nn.Linear(3, 2))
params = tuple(model.parameters())
inputs = torch.randn(2)
targets = torch.randn(2)

def closure():
    """a typical closure but without backward"""
    preds = model(inputs)
    loss = nn.functional.mse_loss(preds, targets)
    return loss

def param_closure(*new_params):
    """closure that takes in params as inputs to work with torch.func.jvp.
    For example an optimizer can create this from normal closure"""
    # swap params to new params
    for old_p, new_p in zip(params, new_params):
        torch.utils.swap_tensors(old_p, new_p)

    value = closure()

    # swap params back to original ones
    for old_p, new_p in zip(params, new_params):
        torch.utils.swap_tensors(old_p, new_p)

    return value

tangents = tuple(torch.randn_like(p) for p in params)
loss, directional_derivative = torch.func.jvp(
    param_closure,
    primals=tuple(p.clone().detach_() for p in params),
    tangents=tangents,
)

Two things though. First, I have to clone and detach the parameters before passing them to the closure, which is a tiny bit inefficient. Second, on a small 1D convnet this doesn’t seem to be any faster than calling backward and computing the dot product of the gradient with my vector, and on a (badly optimized) box-packing task it is actually about 2 times slower than backward. I haven’t tested anything large though.
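For reference, the backward-mode baseline I’m comparing against looks roughly like this (a sketch reusing params, closure and tangents from the example above; since the loss is a scalar, the gradient-vector dot product equals the JVP):

loss = closure()
grads = torch.autograd.grad(loss, params)
grad_dot_v = sum((g * t).sum() for g, t in zip(grads, tangents))
# grad_dot_v should match directional_derivative from torch.func.jvp above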