def closure(params):
    for p, new_p in zip(optimizer_params, params):
        p.set_(new_p)  # this errors when using torch.func.jvp
    preds = model(inputs)
    loss = loss_fn(preds, targets)
    return loss

torch.func.jvp(closure, params, vec)
One way this could be done is via torch.func.functional_call. That way I can do
def closure(params):
    params_dict = {k: p for (k, v), p in zip(model.named_parameters(), params) if v.requires_grad}
    preds = torch.func.functional_call(model, params_dict, (inputs,))
    loss = loss_fn(preds, targets)
    return loss

torch.func.jvp(closure, params_tuple, vec)
However, it would be more convenient if I could somehow use the first way, because the second one requires a torch.nn.Module instead of an arbitrary list of tensors, and I have to rewrite my closure and mess with named_parameters. Another option is to monkey-patch the parameter attributes, for example like this (here used to compute a Hessian):
def f(x):
    _params = self.unravel_unpack(x)
    # monkey patch substitute variables
    named_params = list(model.named_parameters())
    for _p, (n, _) in zip(_params, named_params):
        rdelattr(model, n)
        rsetattr(model, n, _p)
    return closure.loss(model)

def numpyify(x):
    if x.device != torch.device('cpu'):
        x = x.cpu()
    # return x.numpy().astype(np.float64)
    return self.floatX(x.numpy())

return numpyify(torch.autograd.functional.hessian(f, x))
I am afraid that would break references that other things hold to those parameters, and it still requires a torch.nn.Module. But maybe something like this is possible without modifying attributes?
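The suggestion that ended up working below is torch.utils.swap_tensors, which exchanges the contents of two tensor objects in place. Here is a minimal sketch of its behavior (assuming a PyTorch version recent enough to provide it); note that existing references follow the swap rather than breaking:

import torch

a = torch.zeros(3)
b = torch.ones(3)
other_ref = a  # some other reference to the same tensor object

# swap the payloads of the two Python tensor objects in place
torch.utils.swap_tensors(a, b)

print(a)          # tensor([1., 1., 1.])
print(other_ref)  # tensor([1., 1., 1.])  -- same object as `a`, so it sees the swap
print(b)          # tensor([0., 0., 0.])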
Thank you, this actually worked! Here is a minimal example of how to make it work:
import torch, torch.nn as nn

model = nn.Sequential(nn.Linear(2, 3), nn.ReLU(), nn.Linear(3, 2))
params = tuple(model.parameters())
inputs = torch.randn(2)
targets = torch.randn(2)

def closure():
    """a typical closure but without backward"""
    preds = model(inputs)
    loss = nn.functional.mse_loss(preds, targets)
    return loss

def param_closure(*new_params):
    """closure that takes in params as inputs to work with torch.func.jvp.
    For example an optimizer can create this from a normal closure"""
    # swap params to the new params
    for old_p, new_p in zip(params, new_params):
        torch.utils.swap_tensors(old_p, new_p)
    value = closure()
    # swap params back to the original ones
    for old_p, new_p in zip(params, new_params):
        torch.utils.swap_tensors(old_p, new_p)
    return value

tangents = tuple(torch.randn_like(p) for p in params)
torch.func.jvp(param_closure, primals=tuple(p.clone().detach_() for p in params), tangents=tangents)
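For completeness, torch.func.jvp returns both the value of the closure and the directional derivative, so the call above can be unpacked like this (same objects as in the example):

loss, directional_derivative = torch.func.jvp(
    param_closure,
    primals=tuple(p.clone().detach_() for p in params),
    tangents=tangents,
)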
Two things though. First, I have to clone and detach the parameters before passing them to the closure, which is a tiny bit inefficient. Second, on a small 1D convnet this doesn't seem to be any faster than calling backward and computing the dot product of the gradient with my vector, and on a (badly optimized) box-packing task it is actually 2 times slower than backward. I haven't tested anything large though.
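For reference, the backward-based baseline I am comparing against looks roughly like this (a sketch reusing model, params, closure, and tangents from the example above):

# directional derivative via a single backward pass: dot(grad(loss), v)
loss = closure()
grads = torch.autograd.grad(loss, params)
grad_dot_v = sum((g * v).sum() for g, v in zip(grads, tangents))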