Issue with computing Hessian vector products using gradients obtained via hooks in PyTorch

Hi everyone,
I’m trying to implement a method to compute Hessian vector products (HVPs) using PyTorch, specifically using gradients obtained through hooks in a custom GradCAM class. However, I’m encountering an issue where the gradients obtained via hooks seem not to propagate correctly when computing the HVPs.

Here’s a simplified version of my code:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Define a simple convolutional neural network model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)  # Output size: (224 - 3 + 1) = 222
        self.conv2 = nn.Conv2d(16, 32, 3)  # Output size: (222 - 3 + 1) = 220
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # Average pooling to (1, 1)
        self.fc = nn.Linear(32, 10)  # Adjusted to 32 to match the output after pooling

    def forward(self, x):
        x = F.relu(self.conv1(x))  # 224x224 -> 222x222
        x = F.relu(self.conv2(x))  # 222x222 -> 220x220
        x = self.pool(x)  # 220x220 -> 1x1
        x = x.view(x.size(0), -1)  # Flatten tensor to (batch_size, 32)
        x = self.fc(x)  # Fully connected layer
        return x

# Define GradCAM class
class GradCAM(nn.Module):
    def __init__(self, model, target_layer):
        super(GradCAM, self).__init__()
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activation = None
        # Register forward hook
        self.target_layer.register_forward_hook(self.forward_hook)

    def forward_hook(self, module, input, output):
        self.activation = output
        output.register_hook(self.backward_hook)

    def backward_hook(self, grad):
        self.gradients = grad

    def forward(self, x):
        return self.model(x)

# Instantiate model and GradCAM
model = SimpleModel()
target_layer = model.conv2
gradcam = GradCAM(model, target_layer)

# Input tensor
input_tensor = torch.randn(1, 3, 224, 224, requires_grad=True)

# Forward pass
output = gradcam(input_tensor)

# Compute loss and perform backward pass
loss = output.sum()
gradcam.model.zero_grad()
loss.backward(retain_graph=True)

# Get gradients and activation
gradients = gradcam.gradients
activation = gradcam.activation

# Ensure activation and gradients have requires_grad=True
activation.requires_grad_(True)
gradients.requires_grad_(True)

# Compute Hessian-Vector Product
hvp = torch.autograd.grad(
    outputs=gradients,
    inputs=activation,
    grad_outputs=activation,
    retain_graph=True
)

print("Hessian-Vector Product:", hvp)

When attempting to compute the Hessian vector product using torch.autograd.grad, I encounter the following error:

Traceback (most recent call last):
  File "F:\code\torch-cam\torchcam\methods\try2.py", line 71, in <module>
    hvp = torch.autograd.grad(
  File "D:\program\anaconda3\envs\cfr\lib\site-packages\torch\autograd\__init__.py", line 412, in grad
    result = _engine_run_backward(
  File "D:\program\anaconda3\envs\cfr\lib\site-packages\torch\autograd\graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

I’ve ensured that both activation and gradients have requires_grad=True, but the issue persists. How can I correctly compute the Hessian vector product using gradients obtained via hooks in PyTorch?

Any insights or suggestions would be greatly appreciated! Thanks in advance!

I found that for torch.autograd.grad(outputs=gradients, inputs=activation, grad_outputs=activation, retain_graph=True) to work, gradients must have a non-None grad_fn. However, the gradients tensor captured by the hook has grad_fn=None.
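A minimal check of this observation, using a toy tensor instead of the GradCAM model (the names captured and save_grad are illustrative only). During an ordinary backward() the backward computation is not recorded, so the gradient a tensor hook receives has grad_fn=None and requires_grad=False. That would also explain the later "element 0 of tensors does not require grad" error. With create_graph=True the same hooked gradient carries a grad_fn.

```python
import torch

captured = {}

def save_grad(grad):
    # Tensor hook: keep whatever gradient autograd passes in
    captured["grad"] = grad

x = torch.randn(4, requires_grad=True)
h = torch.tanh(x)          # intermediate tensor, playing the role of a hooked activation
h.register_hook(save_grad)
loss = (h ** 3).sum()

# Ordinary backward: the backward pass itself is not recorded in a graph
loss.backward(retain_graph=True)
plain_grad = captured["grad"]
print(plain_grad.grad_fn, plain_grad.requires_grad)  # None False

# create_graph=True records the backward pass, so the hooked gradient gets a grad_fn
x.grad = None
loss.backward(create_graph=True)
graph_grad = captured["grad"]
print(graph_grad.grad_fn is not None)  # True
```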

Hi @Wide_White,

Doesn’t this break the computational graph by resetting requires_grad to True?

When you set requires_grad to True in place, you destroy the tensor's gradient history, which leads to the ‘unused Tensors’ error.

Is there any particular reason why you’re using hooks here?

If you want to compute Hessian-vector products, you can either use torch.autograd.functional.hvp (see the PyTorch 2.3 documentation for torch.autograd.functional.hvp).
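For reference, here is a minimal sketch of torch.autograd.functional.hvp on a toy scalar function f. The function is chosen here for illustration and is not taken from the thread. Because f acts elementwise, its Hessian is diagonal with entries 2 cos(a) - a sin(a), which gives an easy correctness check.

```python
import torch
from torch.autograd.functional import hvp

def f(a):
    # Toy scalar function; stands in for "loss as a function of the activation"
    return (a * a.sin()).sum()

torch.manual_seed(0)
a = torch.randn(2, 3)
v = torch.randn(2, 3)

# hvp returns (f(a), H @ v); H @ v has the same shape as a
value, Hv = hvp(f, a, v)

# f is elementwise, so H is diagonal: d2/da2 [a sin a] = 2 cos a - a sin a
expected = (2 * a.cos() - a * a.sin()) * v
print(torch.allclose(Hv, expected, atol=1e-5))  # True
```

The catch, as the next reply points out, is that hvp needs the loss written as a Python function of the tensor you differentiate with respect to.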

Or, you can use the torch.func library and compose torch.func.jacrev with torch.func.jvp to define your own Hessian-vector product (see the torch.func API Reference in the PyTorch 2.3 documentation).
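A minimal sketch of that composition (forward-over-reverse), assuming a toy scalar function f and an illustrative helper name hvp_fwd_over_rev. jacrev(f) is the gradient map a -> df/da. Taking its forward-mode derivative along a tangent v with jvp gives H @ v without materializing H.

```python
import torch
from torch.func import jacrev, jvp

def f(a):
    # Toy scalar function of a tensor (illustrative only)
    return (a * a.sin()).sum()

def hvp_fwd_over_rev(func, primal, tangent):
    # jacrev(func) maps a -> df/da; the forward-mode derivative of that map
    # along `tangent` is the Hessian-vector product H @ tangent
    _, hv = jvp(jacrev(func), (primal,), (tangent,))
    return hv

torch.manual_seed(0)
a = torch.randn(2, 3)
v = torch.randn(2, 3)
Hv = hvp_fwd_over_rev(f, a, v)
print(Hv.shape)  # torch.Size([2, 3])
```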


Hi @AlphaBetaGamma96, thank you for your response.

I wasn’t aware that setting activation.requires_grad_(True) could break the computational graph by resetting the gradients, which indeed leads to the error of ‘unused Tensors’. However, even after removing these lines, I still encounter the error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

The reason I’m using hooks is that I’m developing a new CAM method within a large library. This library uses hooks to capture intermediate layer outputs and gradients. The torch.autograd.functional.hvp and torch.func APIs require a func parameter, but in my case, func involves feeding activations into the latter part of the neural network, which seems infeasible.
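One possible workaround, sketched under assumptions rather than taken from the thread. A forward hook that returns a tensor replaces the module's output (this is documented behavior of register_forward_hook), so "the rest of the network as a function of the activation" can be built without rewriting the model. SmallModel and tail_loss are hypothetical names. The layers before conv2 are re-run on every call. With the question's ReLU + average-pool + linear tail and a sum loss, the loss is piecewise linear in the conv2 activation, so that HVP would be all zeros. The stand-in therefore uses tanh to get a non-trivial result.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import hvp

class SmallModel(nn.Module):
    # Hypothetical stand-in for SimpleModel; tanh keeps the Hessian non-trivial
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, 3)
        self.conv2 = nn.Conv2d(4, 8, 3)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = torch.tanh(self.conv1(x))
        x = torch.tanh(self.conv2(x))
        return self.fc(self.pool(x).flatten(1))

torch.manual_seed(0)
model = SmallModel()
x = torch.randn(1, 3, 16, 16)

# 1) Capture the conv2 activation with an ordinary (observing) forward hook
captured = {}

def save_act(module, inputs, output):
    captured["act"] = output.detach()

handle = model.conv2.register_forward_hook(save_act)
with torch.no_grad():
    model(x)
handle.remove()
act = captured["act"]  # shape (1, 8, 12, 12)

# 2) Loss as a function of the activation: a hook that *returns* a tensor
#    replaces conv2's output, so only the layers after conv2 see `a`
def tail_loss(a):
    handle = model.conv2.register_forward_hook(lambda module, inputs, output: a)
    try:
        return model(x).sum()
    finally:
        handle.remove()

# 3) HVP w.r.t. the activation, using the activation itself as the vector
_, Hv = hvp(tail_loss, act, act)
print(Hv.shape)  # torch.Size([1, 8, 12, 12])
```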

Therefore, I’m more inclined to implement Hessian-Vector products based on the hook mechanism. I’ve been struggling with this for a week and suspect it might be a bug in the hook implementation or perhaps I’m missing some parameters.

As a supplementary point, the following implementation works:

import torch
import torch.nn as nn

# Define a complex neural network with convolutional layers
class ComplexCNN(nn.Module):
    def __init__(self):
        super(ComplexCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 8 * 8, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        conv2_output = torch.relu(self.conv2(x))  # Get output from conv2 layer
        x = self.pool(conv2_output)
        x = x.view(-1, 32 * 8 * 8)  # Flatten
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x, conv2_output  # Return final output and conv2 layer output

# Initialize network, loss function, and input
model = ComplexCNN()
criterion = nn.MSELoss()
x = torch.randn(1, 3, 32, 32, requires_grad=True)  # Input
y = torch.randn(1, 1)  # Target

# Forward pass to compute loss and conv2 layer output
output, conv2_output = model(x)
loss = criterion(output, y)

# Compute gradient of loss w.r.t. conv2 layer output
grad_conv2 = torch.autograd.grad(loss, conv2_output, create_graph=True)

# Hessian-vector product w.r.t. conv2_output, using conv2_output itself as the vector
hvp = torch.autograd.grad(grad_conv2, conv2_output, grad_outputs=conv2_output, retain_graph=True)

print("Hx:", hvp)

Any potential help is greatly appreciated!

Thank you again for your help!


Hi Alpha!

As a technical aside, I believe that if you have a tensor t for which
requires_grad is True, calling t.requires_grad_ (True) doesn’t really
do anything. It doesn’t “reset” any gradients nor break the computation
graph nor destroy the gradient history.

Consider:

>>> import torch
>>> torch.__version__
'2.3.1'
>>> s = torch.zeros (5, requires_grad = True)
>>> t = 2 * s
>>> t
tensor([0., 0., 0., 0., 0.], grad_fn=<MulBackward0>)
>>> t.requires_grad_ (True)
tensor([0., 0., 0., 0., 0.], grad_fn=<MulBackward0>)
>>> t.sum().backward()
>>> s.grad
tensor([2., 2., 2., 2., 2.])

(Of course, if t has requires_grad equal to False, the computation graph
is already broken, so to speak, and calling t.requires_grad_ (True) won’t
fix any damage that had already been done, but calling it, in and of itself,
won’t do any damage – the damage was already there.)

Best.

K. Frank


I found that if I call loss.backward(retain_graph=True, create_graph=True), the torch.autograd.grad call that computes hvp then works.