Hi, I'm using forward-mode differentiation and I get the same gradients as reverse mode, but much more slowly. The bottleneck in my code is the Hessian-vector product, which uses a for loop. Is there a more efficient way of doing it?
```python
D = n_weights
K = n_hyperparams
Z = torch.zeros((D, K))    # input that changes through time
HZp = torch.zeros((D, K))  # Hessian product
grads = torch.autograd.grad(loss, weights, create_graph=True)[0]  # size (D,)
for k in range(K):
    HZp[:, k] = torch.autograd.grad(grads @ Z[:, k], weights, retain_graph=True)[0]
```
Note that for K=1, forward mode takes the same time as reverse mode; however, it becomes roughly K times slower as K grows. Thanks for any insight.
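For reference, here is a sketch of how the loop could be replaced by a single vectorized call, assuming PyTorch >= 2.0 with the `torch.func` API. The `loss_fn` below is a hypothetical toy quadratic loss standing in for the real one; the HVP is computed forward-over-reverse via `jvp(grad(...))` and batched over the K columns of Z with `vmap`:

```python
import torch
from torch.func import grad, vmap, jvp

D, K = 5, 3
weights = torch.randn(D)
targets = torch.randn(D)
Z = torch.randn(D, K)

def loss_fn(w):
    # toy quadratic loss (Hessian = identity); stand-in for the real loss
    return 0.5 * ((w - targets) ** 2).sum()

def hvp(v):
    # Hessian-vector product H @ v via forward-over-reverse:
    # jvp of the gradient function along direction v
    return jvp(grad(loss_fn), (weights,), (v,))[1]

# map over the K columns of Z in one batched call, no Python loop
HZp = vmap(hvp, in_dims=1, out_dims=1)(Z)  # shape (D, K)
```

Since the toy loss has identity Hessian, `HZp` equals `Z` here; with the real loss this computes all K products in one traced pass instead of K separate backward calls.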