Tensor hooks for gradients from each path in the autograd graph

Hi all,

I’m working on a project where I need to manipulate gradients as they flow backward through each computational path in a model’s computation graph. Currently, I know that PyTorch’s .register_hook method allows you to manipulate gradients for a tensor during backpropagation, but this hook is only called after gradients from all incoming paths have been summed together.

I’d like to “hook in” before this summation step, so that I can observe or modify gradient contributions coming from each computational path (branch) individually, before they are combined at a given node in the graph.

I don’t think this is possible with PyTorch’s public API; if anyone has suggestions that don’t involve modifying the C++ internals, please let me know. I’m also wondering whether there would be demand for this kind of feature as an addition to the hook system - would it make sense as a feature request?

2 Likes

Have you tried using register_full_backward_hook?
The hook will give you access to all grad_input and grad_output entries.
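For reference, a minimal sketch of how that could look, assuming each path is wrapped in its own module (the module names and the `contributions` dict here are just illustrative):

```python
import torch
import torch.nn as nn

# Two branches that both consume the same input tensor `a`
branch1 = nn.Linear(1, 1, bias=False)
branch2 = nn.Linear(1, 1, bias=False)

contributions = {}

def make_hook(name):
    def hook(module, grad_input, grad_output):
        # grad_input[0] is the gradient w.r.t. this branch's input,
        # i.e. this path's contribution before summation at `a`
        contributions[name] = grad_input[0]
    return hook

branch1.register_full_backward_hook(make_hook("branch1"))
branch2.register_full_backward_hook(make_hook("branch2"))

a = torch.tensor([[5.0]], requires_grad=True)
(branch1(a) + branch2(a)).sum().backward()

# The per-branch contributions sum to the accumulated a.grad
print(contributions["branch1"], contributions["branch2"], a.grad)
```

Each hook fires once with its branch's individual gradient, and `contributions["branch1"] + contributions["branch2"]` equals `a.grad`.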

2 Likes

I’m also interested in this question. I made a minimal example of what happens with register_hook:

import torch

a = torch.tensor(5., requires_grad=True)
b1 = 2 * a
b2 = 3 * a

def hook(grad):
    print(f"hook triggered with grad={grad}")

a.register_hook(hook)

torch.autograd.grad(outputs=[b1, b2], inputs=[a])

output: hook triggered with grad=5.0

If I understand correctly you’re looking for a way to have a hook that would trigger twice: once with grad=2.0 (i.e. the gradient of b1 w.r.t. a) and once with grad=3.0 (i.e. the gradient of b2 w.r.t. a).

You can do that with a post-hook on the autograd node of b1 and b2, rather than a pre-hook on a. This would look something like this:

import torch
from torch import Tensor
from torch.autograd.graph import get_gradient_edge

a = torch.tensor(5., requires_grad=True)
b1 = 2 * a
b2 = 3 * a

edge1 = get_gradient_edge(b1)
edge2 = get_gradient_edge(b2)

def hook(grad_inputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...]):
    print(f"hook triggered with grad={grad_inputs}")

edge1.node.register_hook(hook)
edge2.node.register_hook(hook)

torch.autograd.grad(outputs=[b1, b2], inputs=[a])

output:

hook triggered with grad=(tensor(3.), None)
hook triggered with grad=(tensor(2.), None)

I’m not sure why this gives a tuple that also contains a None element (my guess is that it corresponds to the constant factor, which doesn’t require grad), but at least we get the desired two calls with the right gradients.

I think the solution proposed by @ptrblck is similar, for the case where each computation made from a is wrapped in an nn.Module (there, a module post-hook registered via module.register_full_backward_hook on each branch module would do the job).

I hope this helps 🙂
And please correct me if I misunderstood something!

1 Like

Can you also try gradient edges? Here is a minimal example:

import torch
from torch.autograd.graph import get_gradient_edge

def capture_individual_gradients():
    """
    Capture gradient contributions from each computational path
    before they're summed.

    Returns:
        gradients: dict mapping operation names/indices to their
            gradient contributions
        track_operation: helper that registers a hook on an output
            tensor's autograd node and returns the tensor unchanged
    """
    gradients = {}

    def track_operation(output, name=None):
        # Post-hook on the autograd node that produced `output`; it fires
        # with this path's gradient contribution before summation
        edge = get_gradient_edge(output)

        def hook(grad_inputs, grad_outputs):
            # grad_inputs[0] contains the gradient flowing through this path
            grad_key = name if name is not None else f"op_{len(gradients)}"
            grad = grad_inputs[0]
            gradients[grad_key] = grad.clone() if grad is not None else None

        edge.node.register_hook(hook)
        return output

    return gradients, track_operation

# Example usage
a = torch.tensor(5., requires_grad=True)

# Set up gradient capture
gradients, track = capture_individual_gradients()

# Track each operation
b1 = track(2 * a, "multiply_by_2")
b2 = track(3 * a, "multiply_by_3")
b3 = track(a ** 2, "square")

# Compute gradients
loss = b1 + b2 + b3
loss.backward()

print("Individual gradients before summation:")
for name, grad in gradients.items():
    print(f"{name}: {grad}")
print(f"\nSummed gradient at 'a': {a.grad}")

1 Like

Thanks all for the useful answers. I ended up using a modified version of @Hamza_Javaid’s approach.

2 Likes