HVP w.r.t. model parameters

Hi!

I have seen that the 1.5.0 release added the possibility to compute the HVP of a function.

As far as I understand from the documentation, the HVP (as well as the VHP, VJP, and so on) can be computed w.r.t. the inputs only, and not w.r.t. some other variable (such as, for instance, the parameters of a model). Is my understanding correct? Is there any way to use the built-in HVP (or VHP) function to differentiate w.r.t. the model parameters?

Thanks!

Hi,

These functions compute derivatives w.r.t. the inputs of the function you give them. So if your function takes the parameters you want as its inputs and evaluates the model on a fixed input, then it will compute the result w.r.t. those parameters.
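
For instance, a minimal sketch (the function, shapes, and names here are made up for illustration, not from this thread):

import torch
from torch.autograd.functional import vhp

x = torch.rand(5)              # fixed data, captured by the closure
def f(w):
    # w plays the role of the "parameters": vhp differentiates w.r.t. it
    return ((x * w).sum()) ** 2

w = torch.rand(5)
v = torch.ones(5)
out, result = vhp(f, w, v)     # result is the vector-Hessian product v @ H_f(w)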


Hi,

Thanks for your helpful reply! So, if I have a model model and I want to compute the VHP of the loss w.r.t. the parameters, how do I create a function that computes the loss and takes the parameters as input? Right now I have:

f = lambda x: criterion(model(img), label)  # note: x is never used
vhps = [vhp(f, params, grad)[1] for params, grad in zip(model.parameters(), grads)]

where the model is a logistic regression trained on MNIST, so its parameters are a list of tensors with shapes [(10, 784), (10,)]. grads is the gradient of the loss w.r.t. the parameters, so it has the same shapes as the params.

However, with this I get a list with the same shapes as the parameters, but it is all zeros, which makes me suppose that the function doesn’t really depend on model.parameters(). Do you have any suggestions? I’d like the solution to scale to any kind of model, so creating a function that computes the logistic regression from the parameters by hand wouldn’t be very helpful.

Thanks in advance!

Hi,

You don’t have to handle each parameter one by one; you can pass all the params/grads as tuples.

You get all zeros because your function f does not use its input x to compute the output.
You can do something like this to use the autograd API with torch.nn:

# Utilities to make an nn.Module "functional": the parameters are removed
# from the module and passed in explicitly instead
import torch
from torch.autograd.functional import vhp
from torchvision import models

def del_attr(obj, names):
    # Delete the attribute identified by the dotted path in `names`
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    # Set the attribute identified by the dotted path in `names`
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

def make_functional(mod):
    orig_params = tuple(mod.parameters())
    # Remove all the parameters from the model
    names = []
    for name, p in list(mod.named_parameters()):
        del_attr(mod, name.split("."))
        names.append(name)
    return orig_params, names

def load_weights(mod, names, params):
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), p)

device = "cuda" if torch.cuda.is_available() else "cpu"
N = 10
model = models.resnet18(pretrained=False).to(device)
criterion = torch.nn.CrossEntropyLoss()

params, names = make_functional(model)
# Make params regular Tensors instead of nn.Parameter
params = tuple(p.detach().requires_grad_() for p in params)

inputs = torch.rand([N, 3, 224, 224], device=device)
labels = torch.rand(N, device=device).mul(10).long()

def f(*new_params):
    # Put the given tensors back into the (now parameter-less) model,
    # then evaluate the loss
    load_weights(model, names, new_params)
    out = model(inputs)

    loss = criterion(out, labels)
    return loss

# `grads` should be a tuple of tensors with the same shapes as `params`
# (in your case, the loss gradients w.r.t. the parameters)
grads = tuple(torch.rand_like(p) for p in params)
vhp(f, params, grads)

Thanks a lot, it seems to be working!

I only had to change

def load_weights(mod, names, params):
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), p)

to

def load_weights(mod, names, params):
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), torch.nn.Parameter(p))

Because otherwise, after calling make_functional and then f, the model parameters were not working properly, and when calling vhp a second time I get:

v is a tuple of invalid length: should be 0 but got 2.

Which I guess is due to the fact that when we set the weights back, they’re not set as Parameters. Does that make sense?

This is expected!
Parameters can only be leaf Tensors, but here your new module works with any input it is given; the inputs can be leaves or not.

Wrapping these in nn.Parameter might break the graph with the layers above, so you should not do that.
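
A minimal sketch (not from the original thread) of why wrapping a tensor in nn.Parameter cuts it out of the graph, which is also why vhp returns zeros (or raises with strict=True):

import torch

a = torch.rand(3, requires_grad=True)  # plays the role of an input passed to f
p = torch.nn.Parameter(a)              # a new leaf: shares storage with a but has no history
out = (p * 2).sum()
out.backward()
print(p.grad)  # tensor([2., 2., 2.])
print(a.grad)  # None: gradients flow to p, not to the original input a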


I have just realized that, actually, using torch.nn.Parameter makes the VHP 0 again (and with strict=True an exception is raised).

The problem is solved if I load the params with torch.nn.Parameter after the computation.

So I changed load_weights to

def load_weights(model, names, params, as_params=False):
    for name, p in zip(names, params):
        if not as_params:
            set_attr(model, name.split("."), p)
        else:
            set_attr(model, name.split("."), torch.nn.Parameter(p))

And I call

load_weights(model, names, params, as_params=True)

I have just read your answer, though, and I guess this would be wrong too, right?

Do you have any suggestion for getting the nn.Module back after making it functional?

Thanks!

Yes, loading them back as nn.Parameter() is the right way to recover the original nn.Module.
Just be careful that you need to call make_functional again before being able to use the model in a functional way again.
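
A short sketch of the full round trip, assuming the make_functional, load_weights, f, and grads defined earlier in this thread:

# Functional phase: strip the parameters out and differentiate w.r.t. them
params, names = make_functional(model)
params = tuple(p.detach().requires_grad_() for p in params)
result = vhp(f, params, grads)

# Restore the original nn.Module (note as_params=True)
load_weights(model, names, params, as_params=True)

# Before using the model functionally again, strip the parameters again
params, names = make_functional(model)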


Sure, thank you very much!

This is a little old, but I think others may have this same question.

Can you explain why it is necessary to remove the parameters from the model prior to calling vhp? I have been playing with your example, and I can see that yes, it is necessary. Furthermore, I see what you did to restore the model with your newest load_weights(). But I would like to understand why removing the parameters from the model is necessary in the first place.

Thanks very much!

@dedeswim I am getting the same error that the HVP is 0 (and the exception is thrown). I think it’s because once p is cast to a Parameter, it is copied, so the reference to the original p is lost. Did you find a way around this?

import torch

tensor = torch.rand(3)                  # some pre-existing tensor
parameter = torch.nn.Parameter(tensor)
print(parameter.data)
tensor = torch.zeros(tensor.shape)      # rebinds the Python name `tensor`
print(parameter.data)

here ^ both prints output the same result

I have the same confusion here.

Note that today, you can simply use torch.func.functional_call() to achieve this!
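
For instance, a minimal sketch (the small linear model and shapes here are illustrative, not from the thread):

import torch
from torch.func import functional_call
from torch.autograd.functional import vhp

model = torch.nn.Linear(784, 10)
criterion = torch.nn.CrossEntropyLoss()
inputs = torch.rand(10, 784)
labels = torch.randint(0, 10, (10,))

names = [n for n, _ in model.named_parameters()]
params = tuple(p.detach().requires_grad_() for p in model.parameters())

def f(*new_params):
    # Run the module with the given tensors as its parameters,
    # without mutating the module's own state
    out = functional_call(model, dict(zip(names, new_params)), (inputs,))
    return criterion(out, labels)

v = tuple(torch.rand_like(p) for p in params)
loss, result = vhp(f, params, v)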