Accessing gradients prior to accumulation in a graph with multiple paths to the output

Say I have a network in which the output of a module M is consumed by two or more successors in the graph. In reverse-mode autodiff, when the gradients for M’s output are pushed back through the multiple successors, they are accumulated (summed) into a single gradient. Is it possible to access each path’s gradient prior to accumulation?
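
To make the default behavior concrete, here is a minimal sketch (not part of the setup above) in which one tensor feeds two branches; after backward, only the summed gradient is visible on it:

import torch

x = torch.ones(3, requires_grad=True)
a = x * 2   # first consumer of x
b = x * 3   # second consumer of x
(a + b).sum().backward()
print(x.grad)   # tensor([5., 5., 5.]) -- the per-branch contributions (2 and 3) arrive already summed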

Here’s a little hack that does the trick: for each path the output of M takes, apply an identity function and call retain_grad on its output. Each identity application inserts a separate node into the autograd graph, so the gradient arriving along each path is stored on its own tensor before the paths are summed into the gradient of M’s output.

#!/usr/bin/env python
# coding: utf-8

from functools import partial

import torch
import torch.nn as nn


def print_tensor_grad(grad, name=None, value=None):
    # Tensor hook: receives the gradient w.r.t. the tensor it is registered on.
    print(name, 'value', value, 'grad', grad)


def print_module_grad(module, grad_input, grad_out, name=None):
    # Module backward-hook variant (not registered below); prints the
    # gradients w.r.t. the module's inputs.
    print(name, grad_input)


class Network(nn.Module):
    def __init__(self, n_in=2, n_out=2):
        super().__init__()
        self.layer1 = nn.Linear(n_in, n_out, bias=False)
        self.layer2 = nn.Linear(n_out, n_out, bias=False)
        self.layer3 = nn.Linear(n_out*2, 1, bias=False)
        # LeakyReLU with slope 1.0 is numerically the identity, but unlike
        # nn.Identity (which returns its input unchanged) each call creates a
        # new tensor and a new node in the autograd graph.
        self.identity = nn.LeakyReLU(negative_slope=1.0)

    def forward(self, input):
        out1 = self.layer1(input)
        out1.retain_grad()  # keep the accumulated gradient on this non-leaf tensor

        # Route out1 through the identity once per consumer so that each path
        # gets its own tensor, and therefore its own gradient.
        path1 = self.identity(out1)
        path1.retain_grad()
        path2 = self.identity(out1)
        path2.retain_grad()

        out2 = self.layer2(path1)
        input3 = torch.cat((path2, out2), dim=1)
        out3 = self.layer3(input3)

        # These hooks fire during backward(), printing each gradient as it is computed.
        out1.register_hook(
            partial(print_tensor_grad, name='out1', value=out1))
        path1.register_hook(
            partial(print_tensor_grad, name='path1', value=path1))
        path2.register_hook(
            partial(print_tensor_grad, name='path2', value=path2))

        return {
                'out1': out1,
                'path1': path1,
                'path2': path2,
                'y': out3
        }


if __name__ == '__main__':
    torch.manual_seed(17)
    network = Network()
    x = torch.ones(1, 2)
    out = network(x)
    out['y'].backward()
    # Verify that the gradient of the output of the first layer is the
    # same as the sum of the two paths taken by that output.
    print('out1', out['out1'].grad)
    print('path1', out['path1'].grad)
    print('path2', out['path2'].grad)
    assert torch.all(
        out['out1'].grad == out['path1'].grad + out['path2'].grad)
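
The same trick works without a module wrapper: any op that passes the gradient through unchanged while creating a fresh graph node will do. Here is a minimal sketch of my own using Tensor.clone() in place of the LeakyReLU identity (not part of the script above):

import torch

x = torch.ones(1, 2)
w = torch.randn(2, 2, requires_grad=True)

out1 = x @ w            # plays the role of layer1's output
out1.retain_grad()
path1 = out1.clone()    # clone() backpropagates the gradient unchanged
path2 = out1.clone()
path1.retain_grad()
path2.retain_grad()

y = (2 * path1 + 3 * path2).sum()
y.backward()

print(path1.grad)       # gradient along the first path only (all 2s)
print(path2.grad)       # gradient along the second path only (all 3s)
print(out1.grad)        # the accumulated sum of the two (all 5s)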