Thanks a lot, it seems to be working!
I only had to change
```python
def load_weights(mod, names, params):
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), p)
```
to
```python
def load_weights(mod, names, params):
    for name, p in zip(names, params):
        # Wrap each tensor so the module registers it as a Parameter again
        set_attr(mod, name.split("."), torch.nn.Parameter(p))
```
Otherwise, after calling `make_functional` and then `f`, the model parameters no longer worked properly, and calling `vhp` a second time gave me:
v is a tuple of invalid length: should be 0 but got 2.
I guess this is because, when we set the weights back, they are not set as `Parameter`s. Does that make sense?