How can I use torch's `VHP` routine inside of the `step` method when defining a custom optimizer?

So I have been playing around with implementing an optimizer in torch that uses the Hessian-vector product in its update rule. I am keen on using the existing `VHP` for this, but its really awkward API makes it impossible to use inside the `step` method when defining a custom optimizer class.

Here are the problems that I currently have with the existing `VHP` interface:

  1. One typically does not pass any argument to the `step` method when calling `optimizer.step()`, except for algorithms like L-BFGS, which require a `closure` argument because they re-evaluate the loss function multiple times per iteration. Given the API for `VHP`, a typical call looks something like `VHP(loss_func, (model_output_tensor, gt_labels/values_tensor), v)` (I have gripes about this signature as well, which I explain in the next point). But `step` does not typically have access to any of these arguments! `step` has no idea about `loss_func`, `model_output_tensor` or `gt_labels`, so just to get it to work I decided to dump all of these arguments inside of `closure` (a horrible choice, but it seemed to be the only way out), but then…
  2. It seems like `VHP` treats the `inputs` tuple as separate inputs to the function, each handled independently (which is why it explicitly imposes `len(inputs) == len(v)`). But loss functions almost always take vector inputs, and I have no idea how one could pass two vector inputs in a single tensor (except maybe by stacking them vertically, but I am unsure whether `VHP` would parse the inputs that way).

TL;DR: The API for torch’s `VHP` method is really hard to work with when defining a custom optimizer. Are there any interesting alternatives for computing the Hessian-vector product that avoid the pitfalls mentioned above? As far as I am aware, JAX also has an identical API for its own HVP method.

Thank you for your time!
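One alternative to the functional API asked about in the TL;DR is double backward: differentiate the loss once with `create_graph=True`, then differentiate the dot product of that gradient with a vector `v`, which yields `H v` with respect to the parameters themselves. The sketch below is a hypothetical optimizer (the class name `CauchyStepSGD`, its `max_lr` option and the update rule are assumptions for illustration, not an existing torch optimizer). Its closure only returns the loss (unlike the L-BFGS convention, it does not call `backward()`), it uses the gradient `g` as `v`, and it sets the step size to `g·g / g·Hg`, the exact minimizer along `-g` for a quadratic.

```python
import torch


class CauchyStepSGD(torch.optim.Optimizer):
    # Hypothetical example optimizer: steepest descent with step size
    # alpha = (g . g) / (g . H g), capped at max_lr.
    def __init__(self, params, max_lr=1.0):
        super().__init__(params, dict(max_lr=max_lr))

    def step(self, closure):
        # The closure only returns the loss; the model, loss function and
        # labels stay captured inside it, so step() never needs them directly.
        with torch.enable_grad():
            loss = closure()
        params = [p for group in self.param_groups for p in group["params"]
                  if p.requires_grad]
        # First backward pass; create_graph=True keeps the gradient differentiable.
        grads = torch.autograd.grad(loss, params, create_graph=True)
        vs = [g.detach() for g in grads]  # use the gradient itself as v
        # Second backward pass: d/dp (g . v) = H v (double backward).
        hvps = torch.autograd.grad(grads, params, grad_outputs=vs)
        gg = sum((v * v).sum() for v in vs)
        gHg = sum((v * hv).sum() for v, hv in zip(vs, hvps))
        # Exact line-search step for a quadratic; only max_lr applies if gHg <= 0.
        alpha = (gg / gHg).item() if gHg > 0 else float("inf")
        direction = {id(p): v for p, v in zip(params, vs)}
        with torch.no_grad():
            for group in self.param_groups:
                lr = min(group["max_lr"], alpha)
                for p in group["params"]:
                    if id(p) in direction:
                        p.sub_(lr * direction[id(p)])
        return loss.detach()


# Usage on a quadratic 0.5 * x^T A x - b^T x, whose minimizer is A^{-1} b.
A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
b = torch.tensor([1.0, 1.0])
x = torch.zeros(2, requires_grad=True)
opt = CauchyStepSGD([x], max_lr=10.0)


def closure():
    return 0.5 * x @ A @ x - b @ x


for _ in range(50):
    opt.step(closure)
```

Because the HVP is taken with respect to the parameters that `step` already owns, nothing beyond the loss has to travel through the closure.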

(1) looks more like a complaint about the optimizer interface haha, but yes, unfortunately using a closure is what you have to do.
(2) I’m not sure I understand. Isn’t the Hessian of shape inputs × inputs, so `v` also needs to have the same structure as `inputs` in order to compute the vhp? That is why `len(inputs) == len(v)` is necessary.
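To make the shape requirement concrete: for a single tensor input of shape `(n,)`, `v` has that same shape, and `vhp` returns the function value together with `vᵀH`, where `H` is the full `n × n` Hessian. A small check against `torch.autograd.functional.hessian`, using an example function chosen here only for illustration:

```python
import torch
from torch.autograd.functional import hessian, vhp


def f(x):
    # Scalar function of one vector input; its Hessian is diag(6 * x).
    return (x ** 3).sum()


x = torch.tensor([1.0, 2.0, 3.0])
v = torch.tensor([1.0, 0.0, -1.0])      # same shape as the single input
value, v_H = vhp(f, x, v)               # value = f(x), v_H = v^T H
H = hessian(f, x)                       # explicit (3, 3) Hessian for comparison
```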

  1. Haha, I agree XD.
  2. For this point, consider the example of the following function:
```python
import torch

def test_func_multi(x: torch.Tensor, y: torch.Tensor):
    return (2 * y ** 3 - 3 * x ** 2).sum()
```

I do not understand how to use the existing `VHP` API for this (and it doesn’t help that the documentation has no such example). I would want to pass vector inputs for both `x` and `y`, so I call `VHP(test_func_multi, (x_in, y_in), v)`, hoping that `(x_in, y_in)` are used as the arguments of `test_func_multi`, but this throws an error saying that the lengths of the inputs and the vector are not equal. I am interested in a function like this because that is the signature of every loss function.

One way around this could be to partially apply `test_func_multi` using `functools.partial` (since for a loss function, the labels are tensors with `requires_grad=False` anyway).
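A sketch of the `functools.partial` idea with a real loss, assuming `F.cross_entropy` as the loss and random tensors standing in for model outputs and labels: binding `target` leaves a one-argument function, so `inputs` and `v` are single tensors of the same shape. Note that this gives the Hessian with respect to the model outputs (logits); an optimizer usually needs it with respect to the parameters, which this snippet does not cover.

```python
import functools

import torch
import torch.nn.functional as F
from torch.autograd.functional import hessian, vhp

torch.manual_seed(0)
logits = torch.randn(4, 3)            # stand-in for model outputs
labels = torch.tensor([0, 2, 1, 0])   # targets: no gradient needed

# Bind the labels so the remaining function takes a single tensor argument.
loss_of_logits = functools.partial(F.cross_entropy, target=labels)

v = torch.randn_like(logits)          # v has the same shape as the single input
loss, v_H = vhp(loss_of_logits, logits, v)

# Sanity check against the explicit Hessian, which has shape (4, 3, 4, 3).
H = hessian(loss_of_logits, logits)
print(torch.allclose(v_H, torch.tensordot(v, H, dims=2), atol=1e-5))
```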

I had a separate question as well; I would be really thankful if you could help :).

Since posting this question, I have figured out a way to compute the HVP for my use case inside the `step` method without having to pass additional arguments through `closure` (through some autograd magic XD). The optimizer seems to run fine on synthetic experiments (e.g. static functions like Rosenbrock), but it throws a really opaque error after a few iterations on simple neural networks (in this case a simple CNN on CIFAR10).

This is the error that I am getting:

E               RuntimeError: A view was created in no_grad mode and is being modified inplace with grad mode enabled. Given that this use case is ambiguous and error-prone, it is forbidden. You can clarify your code by moving both the view and the inplace either both inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want the inplace to be tracked).

I found this relevant torch issue (I believe you also replied in that conversation?).

I get this error on a line of the `step` method where I update a state variable of the optimizer (this state variable is also involved in the HVP). Interestingly, the error only appears after the algorithm has already run for a few iterations, which makes it really hard to debug. The update itself is very simple: `T[T > constant] = constant` (where `T` is the state variable in question).

Do you have any suggestions for what I could do in this situation?

Thank you so much for your help!

This might be a dumb question, but why wouldn’t you have two `v`s in your example?

Is there any reason the in-place update to the state needs to be done in grad mode? Doing the update to the state under no-grad mode would be one way of resolving this.
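A minimal sketch of that suggestion, assuming the state tensor is updated in place as in the question (the helper name `clamp_state_inplace` is hypothetical): wrapping the write in `torch.no_grad()` matches the first fix offered by the error message, where the view and the in-place update both happen in no-grad mode.

```python
import torch


def clamp_state_inplace(T: torch.Tensor, constant: float) -> torch.Tensor:
    # Autograd is disabled here, so this in-place write is not recorded in any
    # graph. If T is a view created under no_grad, both the view and the
    # in-place update now happen in no-grad mode, as the error message suggests.
    with torch.no_grad():
        T[T > constant] = constant  # same effect as T.clamp_(max=constant)
    return T


# A plain state tensor.
T = torch.tensor([0.5, 2.0, 3.0])
clamp_state_inplace(T, 1.0)  # T is now [0.5, 1.0, 1.0]

# A view of a tensor that requires grad, created under no_grad (the situation
# described in the error message), then updated in place also under no_grad.
w = torch.randn(4, requires_grad=True)
with torch.no_grad():
    view = w[:2]
clamp_state_inplace(view, 0.0)  # writes through to w[:2]
```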