Using `autograd.functional.jacobian`/`hessian` with respect to `nn.Module` parameters

I was pretty happy to see that computing Jacobian and Hessian matrices is now built into the new torch.autograd.functional API, which avoids laboriously writing code with nested for loops and multiple calls to autograd.grad. However, I have had a hard time understanding how to use these functions when the independent variables are parameters of an nn.Module. For example, I would like to use hessian to compute the Hessian of a loss function w.r.t. the model’s parameters. Without an nn.Module, I can successfully compute the Hessian using

import torch

x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = torch.tensor([1.0, 2.0])

def compute_z(a, b):
    output = (x * a.unsqueeze(0) ** 3 + x * b.unsqueeze(0) ** 3).sum(dim=1)
    z = ((output - y)**2).mean()
    return z

a = torch.tensor([1.0, -1.0, 2.0], requires_grad=True)
b = torch.tensor([2.0, -2.0, -1.0], requires_grad=True)
params = (a, b)

hessians = torch.autograd.functional.hessian(compute_z, params, strict=True)

param_names = ('a', 'b')
for d_name, d_hessians in zip(param_names, hessians):
    for dd_name, dd_hessian in zip(param_names, d_hessians):
        print(f'dz/d{dd_name}d{d_name} = \n{dd_hessian}\n')

which works as I would expect:

dz/dada = 
tensor([[ 963.,  198.,  972.],
        [ 198., -801., 1296.],
        [ 972., 1296., 9108.]])

dz/dbda = 
tensor([[ 612.,  792.,  243.],
        [ 792., 1044.,  324.],
        [3888., 5184., 1620.]])

dz/dadb = 
tensor([[ 612.,  792., 3888.],
        [ 792., 1044., 5184.],
        [ 243.,  324., 1620.]])

dz/dbdb = 
tensor([[4068., 3168.,  972.],
        [3168., 2052., 1296.],
        [ 972., 1296., -909.]])
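For reference, here is a small sketch (an editorial addition, not from the original post) that checks the nesting convention of the result: hessians[i][j] holds d²z/(d inputs[i] d inputs[j]) with shape inputs[i].shape + inputs[j].shape, so the two mixed blocks are transposes of each other. It also rebuilds the same 6 x 6 matrix the "laborious" way, with one autograd.grad call per gradient entry.

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = torch.tensor([1.0, 2.0])

def compute_z(a, b):
    output = (x * a.unsqueeze(0) ** 3 + x * b.unsqueeze(0) ** 3).sum(dim=1)
    return ((output - y) ** 2).mean()

a = torch.tensor([1.0, -1.0, 2.0], requires_grad=True)
b = torch.tensor([2.0, -2.0, -1.0], requires_grad=True)

# Functional API: a 2 x 2 nested tuple of 3 x 3 blocks; glue it into 6 x 6.
hessians = torch.autograd.functional.hessian(compute_z, (a, b))
full = torch.cat([torch.cat(row, dim=1) for row in hessians], dim=0)

# Manual version: differentiate each entry of the (graph-carrying) gradient.
z = compute_z(a, b)
grad = torch.cat(torch.autograd.grad(z, (a, b), create_graph=True))  # shape (6,)
manual = torch.stack([
    torch.cat(torch.autograd.grad(g, (a, b), retain_graph=True))
    for g in grad
])  # row k = d(grad[k]) / d(a, b)

print(torch.allclose(full, manual))                      # True
print(torch.allclose(hessians[0][1], hessians[1][0].T))  # True
```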

But when I encapsulate this within an nn.Module, I can’t get it to work. Among other things, I’ve tried

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.tensor([1.0, -1.0, 2.0]))
        self.b = torch.nn.Parameter(torch.tensor([2.0, -2.0, -1.0]))

    def forward(self, x):
        output = (x * self.a.unsqueeze(0) ** 3 + x * self.b.unsqueeze(0) ** 3).sum(dim=1)
        return output

x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = torch.tensor([1.0, 2.0])
net = Net()

def compute_z(*net_parameters):
    # Is there a proper way to set the parameters that works with hessian?
    for p_src, p_dst in zip(net_parameters, net.parameters()):
        p_dst.data = p_src.data

    output = net(x)
    z = ((output - y)**2).mean()
    return z

# strict=True raises exception, allowing non-strict for demonstration
hessians = torch.autograd.functional.hessian(compute_z, tuple(net.parameters()))

param_names = [n for n, _ in net.named_parameters()]
for d_name, d_hessians in zip(param_names, hessians):
    for dd_name, dd_hessian in zip(param_names, d_hessians):
        print(f'dz/d{dd_name}d{d_name} = \n{dd_hessian}\n')

but this produces all zeros:

dz/dada = 
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

dz/dbda = 
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

dz/dadb = 
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

dz/dbdb = 
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

Could someone explain what I’m doing wrong here, and whether there is some way to use autograd.functional.hessian to compute the Hessian w.r.t. a module’s parameters?


I now have a better understanding of why this fails. I think that, to avoid side effects, autograd.functional.jacobian (which hessian uses internally) does not pass the input tensors to the function directly: it performs a no-op on each input and sends the resulting new tensors instead. The problem is that nn.Parameters are necessarily leaf nodes, so any attempt to assign or copy data into the network parameters inside the function necessarily breaks the edge in the computation graph pointing back to the inputs. Furthermore, since the tensors the function receives aren’t the same objects that were passed to jacobian, the inputs to compute_z never end up in the computation graph at all.
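Both effects described above can be reproduced in isolation. This is a hedged sketch with hypothetical names (seen, f, param, src), not code from the thread; exactly what kind of tensor jacobian hands to the function is an internal detail that may vary between PyTorch versions.

```python
import torch

# 1) The function given to jacobian/hessian receives new tensor objects.
seen = []

def f(t):
    seen.append(t)
    return (t ** 2).sum()

inp = torch.ones(3, requires_grad=True)
torch.autograd.functional.jacobian(f, inp)
print(seen[0] is inp)  # False: f got a tensor created inside jacobian

# 2) Writing through `.data` does not create an autograd edge.
param = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
src = torch.tensor([3.0, 4.0], requires_grad=True)
param.data = src.data  # param now shares src's values but is still a leaf
z = (param ** 2).sum()
(g,) = torch.autograd.grad(z, src, allow_unused=True)
print(g)  # None: z is not connected to src in the graph
```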

The only semi-solution I’ve found is to simply remove/comment out this line from the PyTorch implementation. This just removes the no-op, but unfortunately that’s obviously not viable :frowning: My guess is that this may also have other consequences; for example, if there are multiple branching and merging paths in the computation graph, we may not get correct results?

I’m also working on this issue while calculating influence functions.
My current solution follows this workaround:

So I wrote this sample code (which works as a workaround):

import torch
import torch.nn as nn

_input = torch.randn(32, 3)
layer = nn.Linear(3, 4)
criterion = nn.CrossEntropyLoss()

weight = layer.weight

def func(weight):
    # Unregister the Parameter, then attach the tensor that hessian passes in,
    # so the loss stays connected to `weight` in the autograd graph.
    del layer.weight
    layer.weight = weight
    return criterion(layer(_input), torch.zeros(len(_input), dtype=torch.long))

torch.autograd.functional.hessian(func, weight)

The drawback of this workaround is that it’s not safe: you have to delete the weight attribute inside func every time; you can’t do it just once before func. Otherwise, the Hessian will be all zeros.

I would appreciate a more elegant approach, if one exists. If not, maybe we need to think about modifying torch.autograd.functional.hessian.


I wrote a general version of this unsafe monkey patching for substituting weights into the model. The code’s here, as part of a way to call scipy.optimize.minimize as a PyTorch optimizer. It’s the same idea, except that it deep-copies the Module to avoid state problems on subsequent iterations. It’s not efficient.
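The linked code is not reproduced here, so the following is only a sketch of the same idea under stated assumptions: the helpers _swap_tensor and hessian_wrt_params are hypothetical names, parameters are located by their dotted names from named_parameters() (so no attribute names need to be hard-coded), each call patches a deep copy of the model, and tied or shared weights are assumed absent.

```python
import copy

import torch
import torch.nn as nn

def _swap_tensor(module: nn.Module, dotted_name: str, value: torch.Tensor) -> None:
    """Hypothetical helper: replace the Parameter at e.g. "0.weight" with value."""
    *path, leaf = dotted_name.split(".")
    owner = module
    for name in path:
        owner = getattr(owner, name)
    # Unregister the Parameter first; assigning a plain tensor to a
    # registered parameter name would raise a TypeError.
    delattr(owner, leaf)
    setattr(owner, leaf, value)

def hessian_wrt_params(model: nn.Module, loss_fn):
    """Hypothetical helper: Hessian of loss_fn(model) w.r.t. all parameters."""
    names = [n for n, _ in model.named_parameters()]
    params = tuple(p.detach() for _, p in model.named_parameters())

    def func(*flat_params):
        # Patch a throwaway copy so `model` keeps its registered Parameters.
        patched = copy.deepcopy(model)
        for name, value in zip(names, flat_params):
            _swap_tensor(patched, name, value)
        return loss_fn(patched)

    return names, torch.autograd.functional.hessian(func, params)

torch.manual_seed(0)
model = nn.Linear(3, 2)
x = torch.randn(5, 3)
target = torch.randn(5, 2)

names, hess = hessian_wrt_params(model, lambda m: ((m(x) - target) ** 2).mean())
print(names)             # ['weight', 'bias']
print(hess[0][0].shape)  # torch.Size([2, 3, 2, 3])
```

The deep copy on every call is what keeps the original model intact, and it is also why this approach is slow for large models.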

I have encountered the same issue when computing hessian or vhp with respect to network parameters. I hope there’s an elegant way, like .grad.

This relies heavily on knowing the attribute name, as in layer.weight. That is not possible when you only have generic access, such as through layer.parameters().

@gngdb
@shengchao-lin

See the GitHub issue "NN Module functional API" (pytorch/pytorch #49171)
for a better solution, which should be released with PyTorch 1.11.
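As a hedged sketch of what such a solution looks like in practice: assuming PyTorch 2.0 or later, where a functional-call API is exposed as torch.func.functional_call (older releases used a different import path, so check the docs for your version), the Net from the question can be evaluated with substituted parameters without mutating the module, and hessian then sees an intact graph. The names list and the detached params tuple are editorial choices, not from the thread.

```python
import torch
from torch.func import functional_call  # assumes PyTorch >= 2.0

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.tensor([1.0, -1.0, 2.0]))
        self.b = torch.nn.Parameter(torch.tensor([2.0, -2.0, -1.0]))

    def forward(self, x):
        return (x * self.a.unsqueeze(0) ** 3 + x * self.b.unsqueeze(0) ** 3).sum(dim=1)

x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = torch.tensor([1.0, 2.0])
net = Net()

names = [n for n, _ in net.named_parameters()]
params = tuple(p.detach() for p in net.parameters())

def compute_z(*flat_params):
    # Run `net` with flat_params substituted for its parameters; the module's
    # own Parameters are left untouched.
    output = functional_call(net, dict(zip(names, flat_params)), (x,))
    return ((output - y) ** 2).mean()

hessians = torch.autograd.functional.hessian(compute_z, params)

for d_name, d_hessians in zip(names, hessians):
    for dd_name, dd_hessian in zip(names, d_hessians):
        print(f'dz/d{dd_name}d{d_name} = \n{dd_hessian}\n')
```

Because the parameter values and data match the first, module-free example, this should print the same four nonzero matrices shown there.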
