I have seen that the 1.5.0 release added the ability to compute the HVP (Hessian-vector product) of a function.

As far as I understand from the documentation, the HVP (as well as the VHP, VJP, and so on) can only be computed w.r.t. the input of the function, not w.r.t. some other variable (for instance, the parameters of a model). Is my understanding correct? Is there any way to use the built-in HVP (or VHP) function to differentiate w.r.t. the model parameters?

These functions compute derivatives w.r.t. the inputs of the function you give them. So if your function takes the parameters you want as its input and evaluates the model on a fixed input, it will compute the result w.r.t. the parameters.
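A minimal sketch of this idea, assuming a toy least-squares loss (the data `x`, `y` and the weight `w` are made up for illustration): the loss is written as a function of the weights only, with the data captured as fixed constants, so `vhp` differentiates with respect to the weights. For this loss the Hessian is `2 * x.T @ x / n`, which lets us check the result by hand.

```python
import torch
from torch.autograd.functional import vhp

torch.manual_seed(0)
n, d = 20, 3
x = torch.randn(n, d)   # illustrative fixed input data
y = torch.randn(n)      # illustrative fixed targets

def loss_fn(w):
    # Only w is an input of the function; x and y are treated as constants
    return ((x @ w - y) ** 2).mean()

w = torch.randn(d)
v = torch.randn(d)
loss, vhp_w = vhp(loss_fn, w, v)

# Closed form: the Hessian of the mean squared error is 2 * x^T x / n
expected = v @ (2 * x.T @ x / n)
```

Because `w` is the function's input, `vhp_w` has the same shape as `w` and is nonzero.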

Thanks for your helpful reply! So, if I have a model called model and I want to compute the VHP of the loss w.r.t. its parameters, how do I create a function that computes the loss and takes the parameters as input? Right now I have:

f = lambda x: criterion(model(img), label)
vhps = [vhp(f, params, grad)[1] for params, grad in zip(model.parameters(), grads)]

Here the model is a logistic regression trained on MNIST, so its parameters have shapes (10, 784) and (10,). grads is the gradient of the loss w.r.t. the parameters, so it has the same shapes as params.

However, with this I get a list with the same shapes as the parameters, but it is all zeros, which makes me suspect that the function doesn't really depend on model.parameters(). Do you have any suggestions? I'd like the solution to scale to any kind of model, so writing a function that computes the logistic regression by hand from the parameters wouldn't be very helpful.

You don't have to do each parameter one by one; you can pass all the params/grads as tuples.
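For example, with a hand-written logistic regression (used here only to show the calling convention; `W`, `b`, `img` and `label` are illustrative), all parameters and all vectors go into a single call as tuples. `vhp` unpacks the input tuple into the function's positional arguments and returns one result per input.

```python
import torch
import torch.nn.functional as F
from torch.autograd.functional import vhp

torch.manual_seed(0)
img = torch.rand(8, 784)              # illustrative batch of flattened images
label = torch.randint(0, 10, (8,))   # illustrative labels

def loss_fn(W, b):
    # The tuple (W, b) passed to vhp arrives here as two positional arguments
    return F.cross_entropy(img @ W.T + b, label)

W = torch.zeros(10, 784)
b = torch.zeros(10)
v = (torch.randn(10, 784), torch.randn(10))   # one vector per parameter

loss, (vhp_W, vhp_b) = vhp(loss_fn, (W, b), v)
```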

You get all zeros because your function f does not use the inputs x to compute the output.
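A small sketch reproducing the symptom (all names are illustrative): `w_model` plays the role of a parameter stored inside the model, while the function ignores its argument `x`. `vhp` then returns zeros, and with `strict=True` it raises an error instead, which is a quick way to catch this mistake.

```python
import torch
from torch.autograd.functional import vhp

w_model = torch.randn(3, requires_grad=True)   # stands in for model.parameters()
data = torch.randn(3)

def f(x):
    # Bug: x is never used, so the output does not depend on the input
    return (w_model * data).pow(2).sum()

x = torch.randn(3)
v = torch.randn(3)
_, res = vhp(f, x, v)
print(res)  # tensor([0., 0., 0.])

# strict=True turns the silent zeros into an error
try:
    vhp(f, x, v, strict=True)
    raised = False
except RuntimeError:
    raised = True
```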
You can do something like this to use the autograd API with torch.nn:

import torch
import torchvision.models as models
from torch.autograd.functional import vhp

device = "cuda" if torch.cuda.is_available() else "cpu"

# Utilities to make nn.Module functional
def del_attr(obj, names):
    # Walk a dotted name (e.g. "layer1.0.conv1.weight") and delete the last attribute
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    # Same walk, but set the last attribute to val
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

def make_functional(mod):
    orig_params = tuple(mod.parameters())
    # Remove all the parameters in the model
    names = []
    for name, p in list(mod.named_parameters()):
        del_attr(mod, name.split("."))
        names.append(name)
    return orig_params, names

def load_weights(mod, names, params):
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), p)

N = 10
model = models.resnet18(pretrained=False).to(device)
criterion = torch.nn.CrossEntropyLoss()
params, names = make_functional(model)
# Make params regular Tensors instead of nn.Parameter
params = tuple(p.detach().requires_grad_() for p in params)

inputs = torch.rand([N, 3, 224, 224], device=device)
labels = torch.rand(N, device=device).mul(10).long()

def f(*new_params):
    load_weights(model, names, new_params)
    out = model(inputs)
    loss = criterion(out, labels)
    return loss

# v must match the shapes of params; here it is the gradient of the loss, as in the question
grads = torch.autograd.grad(f(*params), params)
_, vhps = vhp(f, params, grads)
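As a smaller check of the same approach, here is a sketch applied to the logistic-regression setup from the question (an `nn.Linear(784, 10)` with random stand-in data, both assumptions for illustration). It compares the result of `vhp` with an explicit double-backward computation of v^T H.

```python
import torch
from torch.autograd.functional import vhp

torch.manual_seed(0)

def del_attr(obj, names):
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

model = torch.nn.Linear(784, 10)           # logistic regression: shapes (10, 784) and (10,)
criterion = torch.nn.CrossEntropyLoss()
img = torch.rand(8, 784)                   # random stand-in for an MNIST batch
label = torch.randint(0, 10, (8,))

# Same steps as make_functional: record names, keep detached values, remove the Parameters
names = [name for name, _ in model.named_parameters()]
params = tuple(p.detach().clone() for p in model.parameters())
for name in names:
    del_attr(model, name.split("."))

def f(*new_params):
    # Install plain tensors where the Parameters used to be, then run the model
    for name, p in zip(names, new_params):
        set_attr(model, name.split("."), p)
    return criterion(model(img), label)

v = tuple(torch.randn_like(p) for p in params)
_, vhps = vhp(f, params, v)

# Reference: the gradient of <grad L, v> w.r.t. the parameters equals v^T H
p_ref = tuple(p.clone().requires_grad_() for p in params)
g = torch.autograd.grad(f(*p_ref), p_ref, create_graph=True)
ref = torch.autograd.grad(sum((gi * vi).sum() for gi, vi in zip(g, v)), p_ref)
```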

I had to change

def load_weights(mod, names, params):
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), p)

to

def load_weights(mod, names, params):
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), torch.nn.Parameter(p))

Because otherwise, after calling make_functional and then f, the model parameters no longer worked properly, and calling vhp a second time raised:

v is a tuple of invalid length: should be 0 but got 2.

I guess this is because, when we set the weights back, they are not registered as Parameters. Does that make sense?
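That explanation seems consistent with a quick check, sketched here with an `nn.Linear` as an assumed stand-in model: once plain tensors are set back, `parameters()` no longer reports them, so a second call to `make_functional` would return an empty tuple, which matches the "should be 0" in the error when a 2-element `v` is passed to `vhp`.

```python
import torch

model = torch.nn.Linear(4, 2)
names = [name for name, _ in model.named_parameters()]
values = [p.detach() for p in model.parameters()]

# What make_functional followed by load_weights (without nn.Parameter) leaves behind
for name, value in zip(names, values):
    delattr(model, name)
    setattr(model, name, value)   # plain tensor: stored as an ordinary attribute

print(len(list(model.parameters())))   # 0
print(model.weight.shape)              # torch.Size([2, 4]); the forward pass still works
```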

I have just realized that using torch.nn.Parameter actually makes the VHP zero again (and with strict=True an exception is raised).
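One likely reason, sketched below with illustrative tensors: `torch.nn.Parameter(p)` builds a new leaf tensor that shares `p`'s data but not its autograd history, so a loss computed from the Parameter is disconnected from the `p` that `vhp` differentiates with respect to.

```python
import torch

p = torch.randn(3, requires_grad=True)    # plays the role of what vhp passes into f
wrapped = torch.nn.Parameter(p)           # plays the role of what load_weights installs

print(wrapped.is_leaf, wrapped.grad_fn)   # True None: a fresh leaf with no link back to p

loss = (wrapped ** 2).sum()
(grad_p,) = torch.autograd.grad(loss, p, allow_unused=True)
print(grad_p)                             # None: the loss does not depend on p
```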

The problem is solved if I load the params with torch.nn.Parameter only after the VHP computation.

So I changed load_weights to

def load_weights(model, names, params, as_params=False):
    for name, p in zip(names, params):
        if not as_params:
            set_attr(model, name.split("."), p)
        else:
            set_attr(model, name.split("."), torch.nn.Parameter(p))

Yes, loading them back as nn.Parameter() looks like a good way to recover the original nn.Module.
Just be careful: you need to call make_functional again before you can use the module in a functional way again.
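Putting the last two replies together, here is a sketch of the full round trip (again with an `nn.Linear` and random data as assumed stand-ins, and with the thread's helpers repeated so the block runs on its own): run `vhp` with plain tensors, restore `nn.Parameter`s with `as_params=True`, then call `make_functional` again before the next functional use.

```python
import torch
from torch.autograd.functional import vhp

def del_attr(obj, names):
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

def make_functional(mod):
    orig_params = tuple(mod.parameters())
    names = []
    for name, _ in list(mod.named_parameters()):
        del_attr(mod, name.split("."))
        names.append(name)
    return orig_params, names

def load_weights(model, names, params, as_params=False):
    for name, p in zip(names, params):
        if not as_params:
            set_attr(model, name.split("."), p)
        else:
            set_attr(model, name.split("."), torch.nn.Parameter(p))

torch.manual_seed(0)
model = torch.nn.Linear(784, 10)
criterion = torch.nn.CrossEntropyLoss()
img, label = torch.rand(8, 784), torch.randint(0, 10, (8,))

def f(*new_params):
    load_weights(model, names, new_params)
    return criterion(model(img), label)

for _ in range(2):
    params, names = make_functional(model)   # has to be redone before every functional use
    params = tuple(p.detach().requires_grad_() for p in params)
    v = tuple(torch.randn_like(p) for p in params)
    _, vhps = vhp(f, params, v)
    # Restore real Parameters only after the VHP has been computed
    load_weights(model, names, params, as_params=True)
    n_params = len(list(model.parameters()))
```

Both iterations see two parameters, because the Parameters restored at the end of the first round are picked up again by the second `make_functional`.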