Say I have a network in which the output of a module M is consumed by two or more successors in the graph. In classical autodiff, when the gradients for M's output flow back from those multiple successors, they are accumulated (summed) into a single gradient. Is it possible to access the per-path gradients prior to accumulation?
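As a quick illustration of that accumulation (a minimal sketch of my own, not part of the network below): when a tensor feeds two consumers, only the summed gradient ends up in .grad.

import torch

x = torch.ones(3, requires_grad=True)
a = (2 * x).sum()   # path 1 contributes a gradient of 2 per element
b = (5 * x).sum()   # path 2 contributes a gradient of 5 per element
(a + b).backward()
print(x.grad)       # tensor([7., 7., 7.]) -- only the accumulated sum is visible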
Here’s a little hack that does the trick: for each path the output of M takes, pass it through an identity function and call retain_grad (or register a hook) on the result. Each identity output then receives only the gradient flowing back along its own path.
#!/usr/bin/env python
# coding: utf-8
from functools import partial
import torch
import torch.nn as nn
def print_tensor_grad(grad, name=None, value=None):
    # Tensor hook: called with the tensor's gradient during backward.
    print(name, 'value', value, 'grad', grad)


def print_module_grad(module, grad_input, grad_out, name=None):
    # Module hook (not used in this example).
    print(name, grad_input)


class Network(nn.Module):
    def __init__(self, n_in=2, n_out=2):
        super().__init__()
        self.layer1 = nn.Linear(n_in, n_out, bias=False)
        self.layer2 = nn.Linear(n_out, n_out, bias=False)
        self.layer3 = nn.Linear(n_out * 2, 1, bias=False)
        # LeakyReLU with negative_slope=1.0 is the identity function.
        self.identity = nn.LeakyReLU(negative_slope=1.0)

    def forward(self, input):
        out1 = self.layer1(input)
        out1.retain_grad()
        # One identity application per consumer of out1, so each path's
        # gradient can be observed before it is accumulated into out1.grad.
        path1 = self.identity(out1)
        path1.retain_grad()
        path2 = self.identity(out1)
        path2.retain_grad()
        out2 = self.layer2(path1)
        input3 = torch.cat((path2, out2), dim=1)
        out3 = self.layer3(input3)
        out1.register_hook(
            partial(print_tensor_grad, name='out1', value=out1))
        path1.register_hook(
            partial(print_tensor_grad, name='path1', value=path1))
        path2.register_hook(
            partial(print_tensor_grad, name='path2', value=path2))
        return {
            'out1': out1,
            'path1': path1,
            'path2': path2,
            'y': out3
        }


if __name__ == '__main__':
    torch.manual_seed(17)
    network = Network()
    x = torch.ones(1, 2)
    out = network(x)
    out['y'].backward()
    # Verify that the gradient of the output of the first layer is the
    # same as the sum of the gradients of the two paths taken by that output.
    print('out1', out['out1'].grad)
    print('path1', out['path1'].grad)
    print('path2', out['path2'].grad)
    assert torch.all(
        out['out1'].grad == out['path1'].grad + out['path2'].grad)
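If defining a dedicated identity module feels heavy, the same trick can be sketched with clone(), which is differentiable and passes gradients through unchanged. This variation is my own shorthand, not part of the network above:

import torch

x = torch.ones(2, requires_grad=True)
path1, path2 = x.clone(), x.clone()   # one clone per consumer of x
path1.register_hook(lambda g: print('path1 grad', g))
path2.register_hook(lambda g: print('path2 grad', g))
((2 * path1).sum() + (5 * path2).sum()).backward()
print('accumulated', x.grad)          # 2 + 5 = 7 per element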