Dear community, does anyone managed to compute a hessian of an arbitrary autograd function?
Let’s consider simple quadratic function: f(x) = x.T A x, it’s gradient is (A + A.T)x and hessian is just (A + A.T).
I have two code snippets, first works:
import torch
from torch import Tensor
from torch.autograd import Variable
from torch.autograd import grad
from torch import nn
torch.manual_seed(623)
x = Variable(torch.ones(2,1), requires_grad=True)
A = torch.FloatTensor([[1,2],[3,4]])
print(A)
print(x)
f = x.view(-1) @ A @ x
print(f)
x_1grad, = grad(f, x, create_graph=True)
print(x_1grad)
print(A @ x + A.t() @ x)
x_2grad0, = grad(x_1grad[0], x, create_graph=True)
x_2grad1, = grad(x_1grad[1], x, create_graph=True)
Hessian = torch.cat((x_2grad0, x_2grad1), dim=1)
print(Hessian)
print(A + A.t())
while the second does not work.
import torch
from torch import Tensor
from torch.autograd import Variable
from torch.autograd import grad
from torch import nn
torch.manual_seed(623)
class Quadro(torch.autograd.Function):
def __init__(self, A):
self.A = torch.FloatTensor(A)
def forward(self, input):
self.save_for_backward(input)
return input.t() @ self.A @ input
def backward(self, grad_output):
input, = self.saved_tensors
input = torch.autograd.Variable(input, requires_grad = True)
grad_input = torch.mul(grad_output, (self.A + self.A.t()) @ input)
return grad_input
x = Variable(torch.ones(2,1), requires_grad=True)
A = torch.FloatTensor([[1,2],[3,4]])
print(A)
print(x)
qq = Quadro(A)
f = qq(x)
print(f)
x_1grad, = grad(f, x, create_graph=True, retain_graph=True)
print(x_1grad)
print(A @ x + A.t() @ x)
x_2grad0, = grad(x_1grad[0], x, create_graph=True)
x_2grad1, = grad(x_1grad[1], x, create_graph=True)
Hessian = torch.cat((x_2grad0, x_2grad1), dim=1)
print(Hessian)
print(A + A.t())
The problem is that I cant differentiate the first derivative, since it has no grad_fn attribute, but I don’t know how to fix it.
P.S. I completely understand, that PyTorch framework was mostly done to train neural networks, not such simple things, but I would be happy, if someone could help me.