Calculating Hessian vectors for outputs with respect to the inputs

Hi. I want to calculate the second derivatives of outputs w.r.t. inputs.
I found some code that calculates Hessian matrices, such as @apaszke's code.

As I understand it, @apaszke's code calculates the Hessian of all elements of y w.r.t. all elements of x. But since y[i] = f(x[i]) in my setting, what I want is only the second derivative of each y[i] w.r.t. its corresponding input x[i]; the cross-sample blocks of the full Hessian are all zero anyway. So I changed some of the code.

I tested the modified code as follows.

import torch.nn as nn
import torch

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x ** 3
        return x

class SigmoidNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())

    def forward(self, x):
        return self.net(x)

class ReLUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())

    def forward(self, x):
        return self.net(x)

def jacobian_vector(y, x, create_graph = False):
    """
        reference: https://gist.github.com/apaszke/226abdf867c4e9d6698bd198f3b45fb7
    """
    jac = []
    batch_size = y.size(0)
    flat_y = y.reshape(-1)
    grad_y = torch.zeros_like(flat_y)
    for i in range(batch_size):
        grad_y[i] = 1.
        grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
        grad_x = grad_x.reshape(x.shape)
        jac_vec = grad_x[i] # y[i] = f(x[i]). so only grad_x[i] is valid for y[i]
        jac.append(jac_vec)
        grad_y[i] = 0.
    return torch.stack(jac, dim=0)

def hessian(y, x):
    first_dev = jacobian_vector(y, x, create_graph=True)
    print("first dev:\n", first_dev)
    return jacobian_vector(first_dev, x)

if __name__=="__main__":
    print("Test Hessian Function")
    torch.autograd.set_detect_anomaly(True)

    x = torch.Tensor([[1,1], [3,3], [2,2]]).requires_grad_(True)
    net = SimpleNet()
    # net = SigmoidNet() # bug

    y = net(x)

    print("x:\n",x)
    print("y:\n",y)

    h = hessian(y, x)

    print("hessian:\n",h)

I have two questions about this code.

  1. When I use SigmoidNet(), the code raises an error saying that a variable needed for gradient computation was modified by an in-place operation, but I cannot find where this modification happens.

This is the error message with torch.autograd.set_detect_anomaly(True):

[W python_anomaly_mode.cpp:60] Warning: Error detected in SigmoidBackwardBackward. Traceback of forward call that caused the error:
  File "hessian_test.py", line 63, in <module>
    h = hessian(y, x)
  File "hessian_test.py", line 46, in hessian
    first_dev = jacobian_vector(y, x, create_graph=True)
  File "hessian_test.py", line 38, in jacobian_vector
    grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
  File "(deleted)/python3.6/site-packages/torch/autograd/__init__.py", line 192, in grad
    inputs, allow_unused)
 (function print_stack)
Traceback (most recent call last):
  File "hessian_test.py", line 63, in <module>
    h = hessian(y, x)
  File "hessian_test.py", line 48, in hessian
    return jacobian_vector(first_dev, x)
  File "hessian_test.py", line 38, in jacobian_vector
    grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
  File "(deleted)/python3.6/site-packages/torch/autograd/__init__.py", line 192, in grad
    inputs, allow_unused)
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3, 1]] is at version 6; expected version 5 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

  2. Even with SimpleNet(), the result is not what I expect. The expected answer is

first dev: # this should be 3x**2
 tensor([[3., 3.],
        [27., 27.],
        [12., 12.]], grad_fn=<StackBackward>)
hessian: # this should be 6x
 tensor([[6., 6.],
        [18., 18.],
        [12., 12.]])

However, the actual output is

Test Hessian Function
x:
 tensor([[1., 1.],
        [3., 3.],
        [2., 2.]], requires_grad=True)
y:
 tensor([[ 1.,  1.],
        [27., 27.],
        [ 8.,  8.]], grad_fn=<PowBackward0>)
first dev:
 tensor([[3., 0.],
        [0., 0.],
        [0., 0.]], grad_fn=<StackBackward>)
hessian:
 tensor([[6., 0.],
        [0., 0.],
        [0., 0.]])

I guess I don’t have a clear understanding of torch.autograd.grad.
What am I doing wrong?
Moreover, I don't understand what the grad_outputs argument of autograd.grad means; the documentation isn't clear to me.
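
My current guess, which may well be wrong, is that grad_outputs is the weight vector v in a vector-Jacobian product v^T (dy/dx), so a one-hot vector should pick out the gradient of a single output element. A tiny check of that reading (the tensors here are just made up for this example):

import torch

x = torch.tensor([1., 2.], requires_grad=True)
y = x ** 2                                     # y = [x0^2, x1^2]
v = torch.tensor([1., 0.])                     # one-hot: select d y[0] / d x
g, = torch.autograd.grad(y, x, grad_outputs=v)
print(g)                                       # expect tensor([2., 0.]), since d y[0] / d x[0] = 2 * x[0]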

I think I found the solution.
What I actually wanted was the Laplacian of the output, i.e. the left-hand side of Laplace's equation.

  1. I found this trick (https://discuss.pytorch.org/t/hessian-of-output-with-respect-to-inputs/60668/10) to solve the error: apply ** 1 to the network output before differentiating (the y = net(x)**1 line in the code below). A possible alternative is sketched right after this list.

  2. I was misunderstanding @apaszke's code: the loop has to run over every element of flat_y, not just over the batch dimension, which gives the full Jacobian of shape y.shape + x.shape; the per-sample blocks are then sliced out afterwards.
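
For the record, another way that I think should avoid the in-place error (a sketch only, I have not verified it carefully) is to never modify grad_y in place and instead allocate a fresh one-hot vector on every iteration:

import torch

def jacobian_fresh_onehot(y, x, create_graph=False):
    # Sketch: same idea as jacobian() below, but each iteration builds a new one-hot
    # grad_y instead of writing into a tensor that create_graph=True may already have
    # saved in the graph from a previous iteration.
    jac = []
    flat_y = y.reshape(-1)
    for i in range(len(flat_y)):
        grad_y = torch.zeros_like(flat_y)   # fresh tensor, not yet part of any graph
        grad_y[i] = 1.
        grad_x, = torch.autograd.grad(flat_y, x, grad_y,
                                      retain_graph=True, create_graph=create_graph)
        jac.append(grad_x.reshape(x.shape))
    return torch.stack(jac, dim=0).reshape(y.shape + x.shape)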

I fixed my code and it seems to do what I want. The code is clumsy, but at least it works.

import torch.nn as nn
import torch

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x ** 3
        return x

class SimpleSumNet(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x ** 3
        return x.sum(dim=1).unsqueeze(1)

class SigmoidNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())

    def forward(self, x):
        x = self.net(x)
        return x

class ReLUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())

    def forward(self, x):
        return self.net(x)

def jacobian(y, x, create_graph = False):
    """
        reference: https://gist.github.com/apaszke/226abdf867c4e9d6698bd198f3b45fb7
    """
    jac = []
    flat_y = y.reshape(-1)
    grad_y = torch.zeros_like(flat_y)
    for i in range(len(flat_y)):
        grad_y[i] = 1.
        grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
        jac.append(grad_x.reshape(x.shape))
        grad_y[i] = 0.

    return torch.stack(jac, dim=0).reshape(y.shape + x.shape)

def laplace(y, x):
    """ Laplace's Equation
        https://en.wikipedia.org/wiki/Laplace%27s_equation
    """
    jac1 = jacobian(y, x, create_graph=True)

    batch_size = jac1.size(0)
    jvs = []
    for i in range(batch_size):
        jvs.append(jac1[i, :, i, :])  # diagonal block: d y[i] / d x[i]

    first_dev = torch.stack(jvs, dim=0)

    jac2 = jacobian(first_dev, x)

    jvs = []
    for i in range(batch_size):
        jvs.append(jac2[i, :, :, i, :])  # diagonal block: d^2 y[i] / d x[i]^2

    second_dev = torch.stack(jvs, dim=0)
    # Laplacian = trace of each per-sample Hessian (sum of the diagonal second derivatives)
    return second_dev.diagonal(dim1=-2, dim2=-1).sum(dim=-1)

if __name__=="__main__":

    x = torch.Tensor([[1,1], [3,3], [2,2]]).requires_grad_(True)
    net = SigmoidNet()

    y = net(x)**1 # trick: https://discuss.pytorch.org/t/hessian-of-output-with-respect-to-inputs/60668/10

    print("x:\n",x)
    print("y:\n",y)

    l = laplace(y, x)
    print("laplace:\n", l, l.shape)

I would like to know if there is a more elegant way to do this.
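
One direction that might be cleaner, if torch.autograd.functional is available in the PyTorch version I'm using, is to let torch.autograd.functional.hessian build the per-sample Hessian and take its trace. laplace_functional below is just a name made up for this sketch, and it assumes a single scalar output per sample, as SigmoidNet has here:

import torch
from torch.autograd.functional import hessian

def laplace_functional(net, x):
    # Sketch only: per-sample Laplacian via torch.autograd.functional.hessian,
    # which expects a function that returns a scalar tensor.
    lap = []
    for xi in x:                                           # loop over the batch
        f = lambda inp: net(inp.unsqueeze(0)).squeeze()    # scalar output for one sample
        H = hessian(f, xi)                                 # (in_dim, in_dim) Hessian
        lap.append(torch.diagonal(H).sum())                # trace = sum of d^2 y / d x_j^2
    return torch.stack(lap)

# net = SigmoidNet()
# x = torch.Tensor([[1, 1], [3, 3], [2, 2]])
# print(laplace_functional(net, x))   # one Laplacian value per sample, shape (3,)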