I want to implement some changes to the Adam optimizer that require the second derivatives of the loss in addition to the first derivatives (gradients).
However, because of how loss.backward() works, I am unable to calculate the second derivatives. Is there a way to achieve this?
Thank you.
Hi @Shubh_Goyal,
You can have a look at torch.autograd.functional.hessian (torch.autograd.functional.hessian — PyTorch 2.3 documentation) to compute second-order derivatives (although it’s pretty computationally expensive).
Or you could have a look at the torch.func namespace and compute higher-order gradients there: torch.func API Reference — PyTorch 2.3 documentation
You can then pass these gradients to the step() method of your custom optimizer.
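For example, something along these lines (just a minimal sketch on a toy scalar loss; the loss_fn and w below are placeholders, not your actual model):

import torch
from torch.autograd.functional import hessian as autograd_hessian
from torch.func import grad, hessian as func_hessian

# Toy scalar loss; `w` stands in for a flat parameter tensor.
def loss_fn(w):
    return (w ** 4).sum()

w = torch.randn(3)

# Option 1: torch.autograd.functional.hessian (expensive for large parameter counts)
H1 = autograd_hessian(loss_fn, w)      # shape (3, 3)

# Option 2: torch.func -- grad for first-order, hessian for second-order terms
g  = grad(loss_fn)(w)                  # first derivatives, shape (3,)
H2 = func_hessian(loss_fn)(w)          # second derivatives, shape (3, 3)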
Hey @AlphaBetaGamma96
Is it not necessary to call loss.backward()?
From what I understand of the second reference you gave, I can compute both the first-order and second-order gradients with the grad function of torch.func and pass them to the optimizer when taking a step. Is that right?
Yes, you can define your own custom torch.optim.Optimizer object whose optim.step() method takes the second-order derivatives as an input. Then you can define the optimizer however you want with the second-derivative terms.
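As a rough sketch (the update rule below is a placeholder, not Adam, and SecondOrderOptimizer is just an illustrative name):

import torch

# Skeleton of an optimizer whose step() takes externally computed
# first- and second-order derivative terms (placeholder update rule, not Adam).
class SecondOrderOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, eps=1e-8):
        super().__init__(params, dict(lr=lr, eps=eps))

    @torch.no_grad()
    def step(self, grads, second_grads):
        # grads / second_grads are lists of tensors aligned with the parameter order
        i = 0
        for group in self.param_groups:
            for p in group["params"]:
                g, h = grads[i], second_grads[i]
                p.add_(-group["lr"] * g / (h.abs() + group["eps"]))
                i += 1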
Hey @AlphaBetaGamma96
I found this article. Do you suggest doing it this way? It is completely functional programming, but the torchoptim library it uses is not official PyTorch, I suppose.
Can you guide me more on how to use torch.func.grad with our torch models?
It would be of great help.
For the torch.func library, I’d recommend reading the docs here.
And for the custom torch.optim.Optimizer object, I’d recommend reading the docs here.
Hey, I have read the docs and written some code, but the loss is not decreasing at all; I haven’t changed the Adam formula yet. Can you look at my code?
I checked the computed gradients and they are all zero.
If you share a minimal reproducible example, I can have a look.
Also, it should be noted that torch.func returns gradients from a function, whereas torch.autograd writes them to the param.grad attribute. So, you might just need to attach them to the .grad attribute of your parameters.
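For example, something like this (a minimal sketch with a toy nn.Linear model and random data; adapt the names to your own code):

import torch
import torch.nn as nn
from torch.func import functional_call, grad

# Minimal sketch: compute gradients with torch.func, then copy them into .grad
# so an ordinary torch.optim optimizer can consume them (toy model and data).
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

params = {k: v.detach() for k, v in model.named_parameters()}

def loss_fn(params):
    out = functional_call(model, params, (x,))
    return ((out - y) ** 2).mean()

grads = grad(loss_fn)(params)          # dict of name -> gradient tensor

with torch.no_grad():
    for name, p in model.named_parameters():
        p.grad = grads[name]           # attach the functional gradients to .grad

optimizer.step()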
import torch
from torch.func import grad

# model_func, loss_fn, params, x_train, y_train and num_epochs are defined elsewhere

# Function to compute the loss
def compute_loss(data, target, params):
    output = model_func(data, params)
    loss = loss_fn(output, target)
    return loss

for i in range(num_epochs):
    loss = compute_loss(x_train, y_train, params)

    # Compute the first-order gradients using torch.func.grad
    first_order_grads = grad(compute_loss, argnums=2)(x_train, y_train, params)

    # Helper that returns the first-order gradients for a given set of params
    def compute_gradients(params):
        return grad(compute_loss, argnums=2)(x_train, y_train, params)

    # Use the first-order gradients to compute second-order derivatives
    second_order_grads = grad(
        lambda p: torch.sum(torch.cat([g.view(-1) for g in compute_gradients(p)])),
        argnums=0,
    )(params)
Do you think this is correct?
I’d advise you to use torch.func.hessian, as it’ll be significantly more efficient than a torch.autograd approach. I have an example on the forums here: Efficient computation of Hessian with respect to network weights using autograd.grad and symmetry of Hessian matrix - #8 by AlphaBetaGamma96
Hey, I did go through it. Is there a more efficient way to do it? I can’t use it for fully connected layers at all.
torch.func works directly on torch primitives, so nn.Linear layers are supported. Just make sure to use torch.func.functional_call when interacting with nn.Module objects to create a ‘functional’ version.
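For instance, a sketch of the Hessian of a loss with respect to the weights of an nn.Module (again with a toy model and random data; adapt to your setup):

import torch
import torch.nn as nn
from torch.func import functional_call, hessian

# Toy model and data; the loss is 'functionalized' so torch.func can
# differentiate with respect to the parameter dict.
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
params = {k: v.detach() for k, v in model.named_parameters()}

def loss_fn(params):
    out = functional_call(model, params, (x,))
    return ((out - y) ** 2).mean()

H = hessian(loss_fn)(params)
# H is a nested dict: H[name1][name2] holds the second derivatives of the loss
# with respect to params[name1] and params[name2]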
Actually, what I meant is that the number of parameters is a problem for me. I have computational constraints, so I am unable to calculate the Hessian.
The Hessian scales quadratically in the number of parameters, so it can become intractable quite quickly. How many parameters does your model have?
Also, if you’re going to multiply the Hessian with a vector (which might not necessarily be the case), it’d be best to directly compute the Hessian-vector product, which avoids computing the Hessian directly and scales much better than naively computing the Hessian then multiplying it by a vector (tutorial here: Jacobians, Hessians, hvp, vhp, and more: composing function transforms — PyTorch Tutorials 2.4.0+cu124 documentation).
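Roughly along these lines (toy scalar loss; the hvp helper follows the forward-over-reverse composition shown in that tutorial):

import torch
from torch.func import grad, jvp

# Hessian-vector product via forward-over-reverse composition:
# never materializes the full Hessian.
def hvp(f, primals, tangents):
    return jvp(grad(f), primals, tangents)[1]

def loss_fn(w):                    # toy scalar loss
    return (w ** 4).sum()

w = torch.randn(5)                 # stand-in for flattened parameters
v = torch.randn(5)                 # the vector to multiply the Hessian by
Hv = hvp(loss_fn, (w,), (v,))      # same shape as w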