Hessian of output with respect to inputs

Is there any way to find the full Hessian matrix of the outputs with respect to the inputs in a simple feed-forward fully-connected network? I know you can call .backward() twice to get a Hessian-vector product, but is there an efficient way (without using a loop) to find the full matrix? Thanks in advance!


Hi,

You can find an example of how to implement such a function in this gist: https://gist.github.com/apaszke/226abdf867c4e9d6698bd198f3b45fb7


Thanks for your quick response @albanD.

I have been implementing something very similar to the example you’ve posted here. What I’m wondering, however, is if there is some way to find the full hessian of the output with respect to the inputs with a single call to torch.autograd.grad?

I'm afraid not.
This is a limitation of backpropagation and automatic differentiation in general: they do not let you compute a Jacobian directly, only a vector-Jacobian product. Thus you need multiple calls to reconstruct the full Jacobian (and similarly for the Hessian).
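
To make the point concrete, here is a minimal sketch (with made-up tensors) showing that a single autograd.grad call only gives one vector-Jacobian product, i.e. one row of the Jacobian when the vector is one-hot:

import torch

x = torch.randn(3, requires_grad=True)
y = x ** 2  # toy function with 3 outputs; its Jacobian is diag(2 * x)

# One call computes v^T J for the chosen v (a vector-Jacobian product).
# A one-hot v extracts a single row of the Jacobian, so recovering the
# full Jacobian (or Hessian) takes one such call per row.
v = torch.zeros_like(y)
v[0] = 1.
row0, = torch.autograd.grad(y, x, grad_outputs=v)
print(row0)  # equals [2 * x[0], 0, 0]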


Thanks for confirming that this is the case!

Hi,

I'm using the same code by Adam to calculate the Hessian of a simple feed-forward fully-connected network's outputs with respect to its inputs, but I get a tensor full of zeros as the Hessian. Do you have any idea why? Any help you could offer would be greatly appreciated!

Thanks,
Li

Hi,

This is most likely because your function is linear, so the Jacobian is constant and the Hessian is 0, no?
Can you share your function?

Hi,

I am working on the same problem. I used the code from the link you posted earlier in this thread: https://gist.github.com/apaszke/226abdf867c4e9d6698bd198f3b45fb7

I am also using a straightforward neural network as a toy example:

import torch
from torch import nn


def jacobian(y, x, create_graph=True):
    # Build the full Jacobian of y w.r.t. x row by row,
    # one vector-Jacobian product per element of y.
    jac = []
    flat_y = y.reshape(-1)
    grad_y = torch.zeros_like(flat_y)
    for i in range(len(flat_y)):
        grad_y[i] = 1.
        grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
        jac.append(grad_x.reshape(x.shape))
        grad_y[i] = 0.
    return torch.stack(jac).reshape(y.shape + x.shape)


def hessian(y, x):
    # The Hessian is the Jacobian of the Jacobian.
    return jacobian(jacobian(y, x, create_graph=True), x)


if __name__ == '__main__':
    x = torch.ones(4, requires_grad=True)
    y = torch.tensor([1.], requires_grad=True)

    model = nn.Linear(4, 1)
    model.weight.data.fill_(0.5)
    model.bias.data.fill_(0.5)

    loss = torch.sum((y - model(x))**2)
    pred = model(x)

    print("loss:", loss)
    print("type(loss):", type(loss))
    print("loss.shape:", loss.shape)
    print(hessian(loss, x))
    print("\n")

    print("pred:", pred[0])
    print("type(pred):", type(pred[0]))
    print("pred.shape:", pred[0].shape)
    print("\n")
    print(hessian(pred[0], x))


And I get the following result and error:

loss: tensor(2.2500, grad_fn=<SumBackward0>)
type(loss): <class 'torch.Tensor'>
loss.shape: torch.Size([])
tensor([[0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000]], grad_fn=<ViewBackward>)

pred: tensor(2.5000, grad_fn=<SelectBackward>)
type(pred): <class 'torch.Tensor'>
pred.shape: torch.Size([])

Traceback (most recent call last):
  File "test_hfm.py", line 58, in <module>
    print(hessian(pred[0], x))
  File "test_hfm.py", line 27, in hessian
    return jacobian(jacobian(y, x, create_graph=True), x)
  File "test_hfm.py", line 19, in jacobian
    grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
  File "/Users/felixasanger/anaconda3/envs/dil_bmw/lib/python3.7/site-packages/torch/autograd/__init__.py", line 158, in grad
    inputs, allow_unused)
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Why does everything work fine for the loss, but not when I try it with the prediction? Is there any difference?
Any help is highly appreciated. I read a lot about this problem, but I just can't solve it.

Best regards and thanks a lot,

Felix

Hi,

Same answer as above :smiley:
The function that gives your prediction given the input is linear (literally just a Linear layer), so the Jacobian is constant and thus independent of the input.
Depending on the case (internal implementation details), we either return a matrix full of 0s or raise an error saying that the two Tensors are independent (meaning that the gradient is 0 everywhere).

Your loss function adds a quadratic term, which makes the Hessian non-zero, so it works fine.
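
For reference, the 0.5 entries printed earlier can be checked by hand: with loss = (y - (w·x + b))**2, the Hessian with respect to x is 2 wᵀw, and with every weight equal to 0.5 that is 0.5 in each entry. A quick sketch of that check:

import torch

w = torch.full((1, 4), 0.5)  # the weights used in the example above
hess = 2 * w.t() @ w         # Hessian of (y - (w x + b))**2 w.r.t. x
print(hess)                  # 0.5 everywhere, matching the printed matrix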

Hi,

Thanks for your answer, but maybe you misunderstood my question. I didn't ask why the Hessian would be 0. I was asking why I get an error when I compute the Hessian of my prediction (NOT the loss) with respect to my input x.

I don't get the error when I do pred = model(x)**1; then it is possible to compute the Hessian of pred with respect to the input. Do you have any idea why it only works when I add the **1?

Best,
Felix

Oh, the answer is that, because of how PyTorch works, a gradient of 0 can be represented in 3 ways:

  • A Tensor full of 0s
  • None being returned by autograd.grad(…, allow_unused=True) or in the .grad field after .backward()
  • An error with autograd.grad(…, allow_unused=False)

Note that we have no way to differentiate between "is independent of" and "has a gradient of 0", and both can be represented in the three ways above.
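
A small sketch illustrating the three cases (the tensors here are made up for illustration):

import torch

x = torch.ones(4, requires_grad=True)
c = torch.tensor(3.0, requires_grad=True)  # never used below

y = (x * 0).sum()  # depends on x, but the gradient is identically 0
g_x, = torch.autograd.grad(y, x)
print(g_x)  # case 1: a Tensor full of 0s

y2 = x.sum()
g_c, = torch.autograd.grad(y2, c, allow_unused=True)
print(g_c)  # case 2: None, since y2 does not use c

# case 3: the same call with allow_unused=False (the default) raises
# "One of the differentiated Tensors appears to not have been used in the graph."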

Hi albanD,

Actually, I am also trying to compute the Hessian of a CNN model's prediction output w.r.t. the input vector.

Let's assume the simple task of MNIST classification. We have a 1x1x28x28 input tensor and a 1x10 output tensor.

Now my aim is to compute the Hessian of a specific class's predicted value w.r.t. the input.

I want to calculate the Jacobian with dimension 784x1. For example, calculating the Jacobian for class 2's predicted value should yield a tensor of dim 784x1.

And calculating the Hessian for class 2's predicted value should yield a matrix of dimension 784x784.

I tried to use the code you provided, but I couldn't manage to adapt it to what I need.

Can you please suggest a way to calculate a Hessian matrix with dimension 784x784 for this specific case, assuming I need to calculate the Hessian of the prediction value of a specific class?

Hi,

If you use the functions in torch.autograd.functional to do this, you will get a Hessian of size input_size + input_size, so in your case 1x1x28x28x1x1x28x28. But you can use .view() to make that 784x784 if that's what you want; you just collapse the dimensions as if the function had an input of size 784.
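
For example, here is a minimal sketch along those lines, assuming a toy classifier and a single image (the model and names are only illustrative):

import torch
import torch.nn as nn

# stand-in for an MNIST classifier
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 32), nn.Tanh(), nn.Linear(32, 10))
image = torch.randn(1, 1, 28, 28)

def class2_score(inp):
    # scalar function: the predicted value for class 2
    return model(inp)[0, 2]

h = torch.autograd.functional.hessian(class2_score, image)
print(h.shape)                    # torch.Size([1, 1, 28, 28, 1, 1, 28, 28])
print(h.reshape(784, 784).shape)  # collapsed to torch.Size([784, 784])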

Hi,

I used the code below to compute the Hessian matrix for a specific class. It seems to work, but I need to double check. In the example below I calculated the Hessian matrix of the output prediction for class 2 with respect to the input tensor. I got a 4x4 Hessian matrix this way.

Can you please comment?

import torch

def jacobian(y, x, create_graph=False):
    # Full Jacobian of y w.r.t. x, one vector-Jacobian product per output element.
    jac = []
    flat_y = y.reshape(-1)
    grad_y = torch.zeros_like(flat_y)
    for i in range(len(flat_y)):
        grad_y[i] = 1.
        grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
        jac.append(grad_x.reshape(x.shape))
        grad_y[i] = 0.
    return torch.stack(jac).reshape(y.shape + x.shape)

def jacobian2(y, x, t, create_graph=False):
    # Gradient of the single output y[t] w.r.t. x (one row of the Jacobian).
    flat_y = y.reshape(-1)
    grad_y = torch.zeros_like(flat_y)
    grad_y[t] = 1.
    grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph, allow_unused=True)
    return grad_x


def hessian(y, x, t):
    # Hessian of the single output y[t] w.r.t. x.
    return jacobian(jacobian2(y, x, t, create_graph=True), x)


def f(x):
    return x * x * torch.arange(4, dtype=torch.float)


x = torch.ones(4, requires_grad=True) * 2

print(jacobian2(f(x), x, 2))
a = hessian(f(x), x, 2)
print(a)

The function has 4 inputs, so you get a Hessian of size 4x4. That looks as expected to me.
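
As a sanity check, if I read the posted f correctly: f(x) = x * x * torch.arange(4) means the selected output is f_2(x) = 2 * x_2**2, so its gradient is (0, 0, 4 * x_2, 0) and its Hessian is zero everywhere except entry (2, 2), which is 4.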