Hessian product slower in forward mode than reverse mode

Hi, I use forward mode differentiation and I get the same gradients as reverse mode, but much slower. The bottleneck of my code is the Hessian vector product, which uses a for loop. Is there a more efficient way of doing it?

D = n_weights       # number of model parameters
K = n_hyperparams   # number of hyperparameters
Z = torch.zeros((D, K))    # input that changes through time
HZp = torch.zeros((D, K))  # Hessian-vector products, one column per hyperparameter

# loss and weights are assumed to already exist, with weights of shape (D,)
grads = torch.autograd.grad(loss, weights, create_graph=True)[0]  # shape (D,)
for k in range(K):
    HZp[:, k] = torch.autograd.grad(grads @ Z[:, k], weights, retain_graph=True)[0]

Note that for K=1, forward mode takes the same time as reverse mode; however, it becomes roughly K times slower as K increases. Thanks for any insight 🙂

I use forward mode differentiation

PyTorch does not have forward mode differentiation. Do you mean that you use the double backward trick to get a Jacobian vector product?

I’m not sure I understand what you are comparing the runtime with. Can you give a full code sample?

Thanks for your time.

The code is very long and our paper isn’t out yet, but hopefully I can provide some clarification instead. When doing backpropagation through time for gradient-based hyperparameter optimization, forward mode differentiation is used because its memory cost scales as O(KD) instead of O(TD) for reverse mode, where K = n_hyperparams, D = n_weights, T = n_iterations. Here my goal is to get the (hyper)gradient of some validation loss after T steps with respect to the K hyperparameters used in those steps.

In practice this means we update a matrix Z of shape (D, K) at each step, such that Z_t = dparams_t/dhyperparams. This update rule requires the Hessian product shown above (see Eqs. (11) and (14) of this paper if interested: http://proceedings.mlr.press/v70/franceschi17a/franceschi17a.pdf), and it accounts for most of the compute time. Once we have Z_T at the final iteration, we can easily get the desired hypergradients with the chain rule.
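
For concreteness, here is a rough sketch of one such update step, under the (simplifying) assumption of plain SGD with the learning rate lr as the only hyperparameter, so K=1 and Z is a single column; train_loss and weights are placeholders for whatever the training step produces:

# one forward-mode accumulation step, assuming the SGD rule w <- w - lr * g
g = torch.autograd.grad(train_loss, weights, create_graph=True)[0]  # dL_train/dparams, shape (D,)
HZ = torch.autograd.grad(g @ Z, weights, retain_graph=True)[0]      # Hessian-vector product H_t Z_t
Z = Z - lr * HZ - g.detach()   # Z_{t+1} = (I - lr * H_t) Z_t - dL_train/dparams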

My feeling is that since I am doing exactly the same operations as the reverse mode but in a different order, I should be able to get similar times.

For small enough T, I can use PyTorch’s autograd by simply making my hyperparameters (e.g. the learning rate) differentiable and backpropagating through the whole training run. This is what I am comparing the above with, roughly as sketched below.
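
Something like this, as a rough sketch of that baseline (again assuming plain SGD with the learning rate as the only hyperparameter; train_loss_fn and val_loss_fn are placeholders):

lr = torch.tensor(0.1, requires_grad=True)  # differentiable hyperparameter
w = torch.randn(D, requires_grad=True)      # initial weights, shape (D,)
for t in range(T):
    g = torch.autograd.grad(train_loss_fn(w), w, create_graph=True)[0]
    w = w - lr * g                          # the graph of all T steps is kept in memory: O(TD)
hypergrad = torch.autograd.grad(val_loss_fn(w), lr)[0]  # dL_val/d(lr) by reverse mode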

So you manually compute the forward mode gradients in Z?
But if you compute the forward mode gradient of the model that gives you the loss, then Z should be the same size as the loss, which is just 1 element.

Sorry, no: at each step I use reverse mode for the gradients of the training loss w.r.t. the model params, and I use that to compute Z. However, I use forward mode manually, by carrying Z forward to the final step, to get the final (hyper)gradient of interest. So Z is an intermediate quantity that I need to compute the final hypergradient at step T, which is dL_val/dhyperparams = (dL_val/dparams_T) * (dparams_T/dhyperparams) = (dL_val/dparams_T) * Z_T.
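
In code this final step is roughly the following (sketch only; val_loss is the validation loss evaluated at the final weights weights_T, and Z_T has shape (D, K)):

dLval_dparams = torch.autograd.grad(val_loss, weights_T)[0]  # shape (D,)
hypergrads = Z_T.T @ dLval_dparams                           # dL_val/dhyperparams, shape (K,)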

If you’re familiar with MAML in meta-learning, this is like using reverse mode for the inner loop and forward mode for the outer loop.

I am still confused, as computing the product (dL_val/dparams) * Z_T is a forward mode AD computation on the model.
If you wanted to do that, you would have to use the backward-of-backward trick in PyTorch.
Or compute the full Jacobian of the model, but that is going to be expensive, and that is expected.

Perhaps forward mode is used loosely in this literature; what I mean here is that, contrary to reverse mode, I don’t have to store all the intermediate instances of my weights in memory to compute the final hypergradient. It does not refer to forward mode on dL_train/dparams.

My question is: in the snippet above, can this sort of operation (a for loop writing one column at a time) be turned into a single matrix multiplication and a single call to autograd.grad?

This is slightly slower on CPU…

grads = torch.autograd.grad(loss, weights, create_graph=True)[0]  # shape (D,)
grads_times_Z = torch.mm(grads.unsqueeze(1).T, Z)                 # shape (1, K)
for k in range(K):
    HZp[:, k] = torch.autograd.grad(grads_times_Z[:, k], weights, retain_graph=True)[0]

The problem is not the product, but that you only want the gradients corresponding to part of it. In some sense, you have K different loss functions, so if you want the gradients for each of them you need to do K backwards. I don’t think there is any way around this.