Calculating diagonal sum of Hessian matrix for each sample

I’m trying to calculate the diagonal sum (trace) of the Hessian matrix w.r.t. the input.
I’m working from @apaszke's code.
But when I tested the code, it was really slow.
I think that is because the code loops over every element of the y matrix, which isn’t what I want.

What I want is the diagonal sum of the Hessian matrix for each sample in a batch. Is there any faster way to do this?

Hi,

Unless the Hessian is already a diagonal matrix (because of the structure of your network), which is very unlikely if you’re using a CNN, it’s not really possible to compute this “quickly”.
You can try a package like this one, though, which lets you compute an approximation faster.


Thank you very much! I’ll try that package.

I have a follow-up question.
I tried to calculate the Hessian matrix for each row in a batch.
The code below is meant to compute the diagonal of the Hessian matrix w.r.t. the input.
The function is a simple fully connected network.

# pytorch

import torch
import torch.nn as nn

def pth_jacobian(y, x, create_graph = False):
    """
        reference: https://gist.github.com/apaszke/226abdf867c4e9d6698bd198f3b45fb7
    """
    jac = []
    flat_y = y.reshape(-1)
    grad_y = torch.zeros_like(flat_y)
    for i in range(len(flat_y)):
        # One-hot vector selects output element i for this backward pass
        grad_y[i] = 1.
        grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
        jac.append(grad_x.reshape(x.shape))
        grad_y[i] = 0.

    return torch.stack(jac, dim=0).reshape(y.shape + x.shape)

def pth_hessian(y, x):
    return pth_jacobian(pth_jacobian(y, x, create_graph = True), x)

def pth_hessian_with_loop(y, x):
    Hs = []
    batch_size = y.size(0)
    for i in range(batch_size):
        H = pth_jacobian(pth_jacobian(y[i], x[i], create_graph = True), x)
        Hs.append(H)

    return torch.stack(Hs, dim=0)

def pth_hessian_diag(y, x):
    H = pth_hessian(y, x)
    batch_size = y.size(0)

    diag_vec = []
    for i in range(batch_size):
        diag_vec.append(H[i, :, i, :, i, :])

    diag = torch.stack(diag_vec, dim = 0)
    x_dim = x.size(1)

    diag_vec = []
    for i in range(x_dim):
        diag_vec.append(diag[:,:,i,i])

    diag = torch.stack(diag_vec,dim=-1)
    return diag


class FC(nn.Module):
    def __init__(self, nc_in, nc_out, num_channels):
        super().__init__()

        if not isinstance(num_channels, list):
            num_channels = [num_channels]

        modules = []
        self.nc_in = nc_in
        self.nc_out = nc_out

        for nc in num_channels:
            modules.append(nn.Linear(nc_in, nc))
            modules.append(nn.Sigmoid())
            nc_in = nc

        modules.append(nn.Linear(nc_in, nc_out))
        modules.append(nn.Sigmoid())

        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)

def main():
    batch_size = 8
    print("Batch size: ", batch_size)
    # pytorch main
    print("Run pytorch")
    x = torch.rand(batch_size, 3).requires_grad_(True).to("cuda:0")
    torch_net = FC(3, 1, [256]).to("cuda:0")
    y = torch_net(x) ** 1 # trick: https://discuss.pytorch.org/t/hessian-of-output-with-respect-to-inputs/60668/10

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    #torch_hd = pth_hessian_diag(y, x)
    torch_hd = pth_hessian_with_loop(y, x)
    end.record()

if __name__=="__main__":
    main()

But then, I encountered this error.

Batch size:  8
Run pytorch
Traceback (most recent call last):
  File "test.py", line 97, in <module>
    main()
  File "test.py", line 93, in main
    torch_hd = pth_hessian_with_loop(y, x)
  File "test.py", line 32, in pth_hessian_with_loop
    H = pth_jacobian(pth_jacobian(y[i], x[i], create_graph = True), x)
  File "test.py", line 18, in pth_jacobian
    grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
  File "(removed)/python3.6/site-packages/torch/autograd/__init__.py", line 192, in grad
    inputs, allow_unused)
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

I guess it is not possible to calculate the gradient of just part of a tensor. Am I right?

Hey,

Note that you should not do x = torch.rand(batch_size, 3).requires_grad_(True).to("cuda:0"), because the x you get here is not a leaf: it is the output of .to(), so it won’t have its .grad field populated when you call .backward(). You should create it directly on the device instead: x = torch.rand(batch_size, 3, requires_grad=True, device="cuda:0")
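
To make the leaf / non-leaf difference concrete, here is a minimal sketch. It assumes a CPU-only machine, so a dtype conversion with .to(torch.float64) stands in for the .to("cuda:0") device move (that substitution is mine, not from the original post); both return a new tensor produced by an autograd operation.

```python
import torch

# Non-leaf: requires_grad_() marks the float32 tensor, then .to() returns a
# *new* tensor that is the output of an autograd op (assumption: dtype
# conversion used here in place of the device move so this runs without a GPU).
x_bad = torch.rand(8, 3).requires_grad_(True).to(torch.float64)

# Leaf: the tensor is created directly with its final dtype (or device).
x_good = torch.rand(8, 3, dtype=torch.float64, requires_grad=True)

print(x_bad.is_leaf, x_good.is_leaf)

# backward() only populates .grad on leaf tensors.
(x_good ** 2).sum().backward()
print(x_good.grad.shape)
```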

I guess it is not possible to calculate the gradient of just part of a tensor. Am I right?

You are right: you cannot ask for the gradient with respect to x[i]. Indexing creates a new Tensor, different from x, and that new Tensor was not used to compute your function.
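
One way around this, sketched below, is to always differentiate with respect to the full x and slice the result afterwards. The helper per_sample_hessian_diag is a hypothetical name, and the sketch assumes each output y[b] depends only on its own row x[b] (true for a plain fully connected network, not for layers that mix samples such as batch norm in training mode). Under that assumption it loops over the input dimension rather than over the batch.

```python
import torch
import torch.nn as nn

def per_sample_hessian_diag(y, x):
    """Hypothetical helper: diagonal of d^2 y[b] / d x[b]^2 for every sample b.

    y: (B,) or (B, 1) outputs computed from x; x: (B, D) tensor in the graph.
    Assumes y[b] depends only on x[b].
    """
    # Samples do not interact, so the gradient of the batch sum holds each
    # sample's own gradient in its row.
    (g,) = torch.autograd.grad(y.sum(), x, create_graph=True)  # (B, D)
    diag = []
    for d in range(x.shape[1]):
        # Differentiate the d-th gradient component again; column d of the
        # result is the d-th diagonal Hessian entry for every sample.
        (g2,) = torch.autograd.grad(g[:, d].sum(), x, retain_graph=True)
        diag.append(g2[:, d])
    return torch.stack(diag, dim=-1)  # (B, D)

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 16), nn.Sigmoid(), nn.Linear(16, 1), nn.Sigmoid())
x = torch.rand(8, 3, requires_grad=True)   # leaf tensor
diag = per_sample_hessian_diag(net(x), x)  # shape (8, 3)
trace = diag.sum(dim=1)                    # diagonal sum per sample, shape (8,)

# Differentiating w.r.t. x[0] fails: x[0] is a fresh tensor that the
# forward pass never saw.
try:
    torch.autograd.grad(net(x).sum(), x[0])
except RuntimeError as err:
    print("x[0] is not part of the graph:", err)
```

The cost is one extra backward pass per input dimension, which is cheap when the input is small (3 in the example above) but does not scale to image-sized inputs.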


You can use the Hessian-vector product from torch.autograd.functional.hvp (see the PyTorch 1.10.0 documentation) to estimate the trace (diagonal sum) of the Hessian. Just let the vector be random, with independent ±1 entries. This is an unbiased estimator with variance 2(||X||_F^2 - sum_i X_{ii}^2), where X is the Hessian, meaning that if the Hessian is close to diagonal, you’ll need very few samples for an accurate estimate.
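
A rough sketch of that estimator for the per-sample case follows. The function name hutchinson_trace and the num_samples parameter are my own, and the sketch again assumes samples do not interact, so the Hessian of the batch-summed output is block diagonal with one block per sample.

```python
import torch

def hutchinson_trace(func, x, num_samples=64):
    """Hypothetical helper: estimate the per-sample Hessian trace w.r.t. x.

    func maps a (B, D) batch to (B,) or (B, 1) outputs, one per sample.
    """
    est = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    for _ in range(num_samples):
        # Random vector with independent +1 / -1 entries.
        v = torch.randint(0, 2, x.shape, device=x.device).to(x.dtype) * 2 - 1
        # hvp needs a scalar-valued function, so sum over the batch.
        _, hv = torch.autograd.functional.hvp(lambda inp: func(inp).sum(), x, v)
        # v^T H v, evaluated separately for each sample's block.
        est += (v * hv).sum(dim=1)
    return est / num_samples

x = torch.rand(8, 3)
# Toy function with a diagonal Hessian (6 * x on the diagonal), chosen so the
# estimate is exact regardless of the random vector.
approx = hutchinson_trace(lambda inp: (inp ** 3).sum(dim=1), x, num_samples=4)
print(approx)
```

For a real network the estimate is noisy, and num_samples trades accuracy against the cost of one Hessian-vector product per sample drawn.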
