HVP w.r.t. model parameters

Thanks a lot, it seems to be working!

I only had to change

def load_weights(mod, names, params):
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), p)
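The post does not show `set_attr` or `make_functional`. The sketch below assumes they are the usual recursive `getattr`/`setattr`/`delattr` helpers; their bodies, the model and the sizes are my assumptions. It shows what the original `load_weights` does to a module: the forward pass still works, but the module no longer reports any parameters.

```python
import torch
import torch.nn as nn

# Assumed helpers: follow a dotted name such as "0.weight" down the submodules.
def del_attr(obj, names):
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

def make_functional(mod):
    # Remove every Parameter from the module, remembering its name and value.
    orig_params = tuple(mod.parameters())
    names = []
    for name, _ in list(mod.named_parameters()):
        del_attr(mod, name.split("."))
        names.append(name)
    return orig_params, names

def load_weights(mod, names, params):
    # Original version: plain tensors end up as ordinary attributes.
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), p)

model = nn.Sequential(nn.Linear(3, 2))
orig_params, names = make_functional(model)
load_weights(model, names, [p.detach().clone() for p in orig_params])

out = model(torch.randn(4, 3))        # the forward pass still works
print(len(list(model.parameters())))  # 0: nothing is registered any more
```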

to

def load_weights(mod, names, params):
    # Wrap each tensor so that nn.Module registers it as a Parameter again
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), torch.nn.Parameter(p))
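With the patched version, `nn.Module.__setattr__` sees an `nn.Parameter` and registers it, so `model.parameters()` is populated again. One thing worth checking: `torch.nn.Parameter(p)` creates a new leaf tensor that shares storage with `p` but drops its autograd history. If `f` calls `load_weights` on its inputs, the loss then no longer depends on the tensors `vhp` differentiates with respect to. With the default `strict=False`, the HVP can come back as all zeros without any error. The sketch below uses a single `nn.Linear`, a hypothetical squared-output loss, and inlines `make_functional` as plain `delattr` calls.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import vhp

model = nn.Linear(3, 2)
names = [name for name, _ in model.named_parameters()]  # ["weight", "bias"]
orig_params = tuple(model.parameters())
for name in names:
    delattr(model, name)  # what make_functional does

def load_weights(mod, names, params):
    # Patched version: nn.Parameter(p) registers the tensor but detaches it.
    for name, p in zip(names, params):
        setattr(mod, name, nn.Parameter(p))

x = torch.randn(4, 3)

def f(*new_params):
    load_weights(model, names, new_params)
    return (model(x) ** 2).sum()  # hypothetical loss

params = tuple(p.detach().clone() for p in orig_params)
v = tuple(torch.ones_like(p) for p in params)
_, hvp = vhp(f, params, v)

print(len(list(model.parameters())))         # Parameters are registered again
print([h.abs().sum().item() for h in hvp])   # but the graph was cut: all zeros
```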

Otherwise, after calling make_functional and then f, the model's parameters were no longer picked up properly, and calling vhp a second time gave me this error:

v is a tuple of invalid length: should be 0 but got 2.
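In that message, `torch.autograd.functional.vhp` is comparing the length of `v` with the length of `inputs`. "Should be 0" means the inputs tuple was empty. The sketch below reproduces this under the assumption that the second call builds its inputs with `tuple(model.parameters())` after the original `load_weights` has replaced the Parameters with plain tensors. The model, the loss and the inlined `make_functional` steps are hypothetical.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import vhp

model = nn.Linear(3, 2)
v = tuple(torch.ones_like(p) for p in model.parameters())  # 2 tensors

# Mimic make_functional + the original load_weights: plain tensors, not Parameters.
for name in ["weight", "bias"]:
    value = getattr(model, name).detach().clone()
    delattr(model, name)
    setattr(model, name, value)

params = tuple(model.parameters())  # now an empty tuple
x = torch.randn(4, 3)

def f(*new_params):
    return (model(x) ** 2).sum()  # hypothetical loss; vhp fails before calling it

message = None
try:
    vhp(f, params, v)
except RuntimeError as err:
    message = str(err)
print(message)  # expected to match the error quoted in the post
```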

I guess this is because, when we set the weights back, they are stored as plain tensors rather than as Parameters, so the module no longer registers them. Does that make sense?
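That explanation matches how `nn.Module.__setattr__` behaves. Once `make_functional` has deleted a parameter, assigning a plain tensor under the same name creates an ordinary attribute, while assigning an `nn.Parameter` registers it. Because `nn.Parameter(p)` also detaches `p` from the autograd graph, one alternative is worth trying. Keep the plain-tensor `load_weights` for the call to `vhp`, then restore the original `nn.Parameter` objects once it returns. This is a swapped-in approach, not something confirmed in this post. In the sketch below, `del_attr`, `set_attr` and `make_functional` are assumed recursive helpers, and `model_vhp`, the model and the loss are hypothetical.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import vhp

def del_attr(obj, names):
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

def make_functional(mod):
    orig_params = tuple(mod.parameters())
    names = [name for name, _ in mod.named_parameters()]
    for name in names:
        del_attr(mod, name.split("."))
    return orig_params, names

def load_weights(mod, names, params):
    # Unchanged: plain tensors keep the graph from vhp's inputs to the loss.
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), p)

def model_vhp(model, loss_of_model, v):
    # Hypothetical wrapper: strip the Parameters, run vhp, then restore them.
    orig_params, names = make_functional(model)

    def f(*new_params):
        load_weights(model, names, new_params)
        return loss_of_model(model)

    try:
        _, hvp = vhp(f, tuple(p.detach() for p in orig_params), v)
    finally:
        # Assigning the original nn.Parameter objects re-registers them.
        load_weights(model, names, orig_params)
    return hvp

model = nn.Sequential(nn.Linear(3, 2))
x = torch.randn(4, 3)

def loss(m):
    return (m(x) ** 2).sum()

v = tuple(torch.ones_like(p) for p in model.parameters())
hvp1 = model_vhp(model, loss, v)
hvp2 = model_vhp(model, loss, v)      # a second call works as well
print(len(list(model.parameters())))  # 2
```

Restoring the Parameters is also why `make_functional` runs again at the start of every call. Once `weight` is registered as a Parameter, `nn.Module` raises a `TypeError` if you assign a plain tensor to it.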