I am trying to use the torch.autograd.functional.vhp function with respect to model parameters. I defined a loss function that takes the model parameters as input, but I am having trouble loading those parameters into the model without inadvertently copying them.
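For comparison, here is a minimal sketch of one common way to write a loss that takes the parameters as input without assigning them to the model at all. It assumes PyTorch 2.0 or later, where torch.func.functional_call exists (some older releases expose a similar torch.nn.utils.stateless.functional_call). The nn.Linear model, the random data, and the names loss_fn, names, params and v are hypothetical stand-ins, not the original code.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import vhp
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical model and data standing in for the real ones.
model = nn.Linear(3, 1)
x = torch.randn(8, 3)
y = torch.randn(8, 1)

# Parameter names plus detached copies of their current values;
# vhp differentiates with respect to this tuple of plain tensors.
names = [name for name, _ in model.named_parameters()]
params = tuple(p.detach().clone() for _, p in model.named_parameters())

def loss_fn(*flat_params):
    # functional_call runs model.forward with these tensors substituted for
    # the registered parameters, so no setattr or nn.Parameter wrapping is
    # needed and the loss stays connected to the tensors vhp passes in.
    out = functional_call(model, dict(zip(names, flat_params)), (x,))
    return ((out - y) ** 2).mean()

# One vector per parameter, with matching shapes.
v = tuple(torch.ones_like(p) for p in params)

loss, hv = vhp(loss_fn, params, v)
print(loss.item(), [t.shape for t in hv])
```

Because functional_call only swaps the tensors in for the duration of one forward pass, the model's own parameters are left unchanged and nothing has to be wrapped in nn.Parameter.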
Given some parameter p, when I use setattr(model, param_name, p), I get an error saying that p needs to be wrapped in a Parameter. When I wrap it with setattr(model, param_name, torch.nn.Parameter(p)), the connection to the original p is lost: the new Parameter is a fresh leaf tensor that shares p's storage but is detached from p's autograd graph, so vhp returns zeros.
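A minimal reproduction of the zero result, under a similar hypothetical setup (a bias-free nn.Linear and random data; loss_fn and w0 are illustrative names). It shows that nn.Parameter(w) does not keep w in the graph: it returns a new leaf tensor sharing w's storage, so as far as autograd is concerned the loss does not depend on the tensor vhp differentiates with respect to, and vhp (with its default strict=False) fills in zeros instead of raising.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import vhp

torch.manual_seed(0)

# Hypothetical bias-free model and data, so the only parameter is "weight".
model = nn.Linear(3, 1, bias=False)
x = torch.randn(8, 3)
y = torch.randn(8, 1)

def loss_fn(w):
    # nn.Parameter(w) builds a new leaf tensor that shares w's storage but
    # is detached from w's autograd graph, so the loss below is not
    # differentiable with respect to w.
    setattr(model, "weight", nn.Parameter(w))
    return ((model(x) - y) ** 2).mean()

w0 = model.weight.detach().clone()
loss, hv = vhp(loss_fn, w0, torch.ones_like(w0))
print(hv)
```

Passing strict=True to vhp should turn this silent zero into an error, which can help confirm that the graph is being cut.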
This is related to this old thread: HVP w.r.t model parameters - #5 by dedeswim
Thanks!