Saving gradients of operators

I’m trying to save the intermediate gradients of a graph, as done in Table 3 of Baydin et al. I have custom modules whose gradients I’d like to know, so I registered a backward hook on each of them. The gradients of the leaf nodes are correct, but the gradients I get from the hooks aren’t. How do I get the gradients that flow through operators?

Here’s what I have now:

import torch
import torch.nn as nn

class Node(nn.Module):
    """Base class for a node in an expression tree.

    Operands live in the plain list ``self.children`` and a display label in
    ``self.name``; concrete node types override ``forward``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): this instance attribute shadows nn.Module.children().
        # Because it is a plain list, child nodes are not registered as
        # submodules: parameters() will not see them, and Module methods that
        # call self.children() (e.g. apply()/to()) will fail on this object.
        self.children = []
        self.name = 0  # placeholder label; subclasses assign a string

    def forward(self, x):
        """The base node computes nothing and returns None."""
        return None

    def get_subtree(self):
        """Return this node followed by all descendants in depth-first pre-order.

        No visited-set is kept, so a node reachable along several paths
        appears once per path.
        """
        collected = []
        self._get_subtree_recursive(collected)
        return collected

    def _get_subtree_recursive(self, subtree):
        """Append this node, then every child's subtree, to ``subtree``."""
        subtree.append(self)
        for child in self.children:
            child._get_subtree_recursive(subtree)

class Plus(Node):
    """Sum of two sub-expressions: ``child0(x) + child1(x)``."""

    def __init__(self, child0, child1):
        super().__init__()
        self.children = [child0, child1]
        self.name = "Plus"

    def forward(self, x):
        """Evaluate every child on ``x`` and add the first two results."""
        operands = [child(x) for child in self.children]
        first, second = operands[0], operands[1]
        return first + second

class Minus(Node):
    """Difference of two sub-expressions: ``child0(x) - child1(x)``."""

    def __init__(self, child0, child1):
        super().__init__()
        self.children = [child0, child1]
        self.name = "Minus"

    def forward(self, x):
        """Evaluate every child on ``x``; subtract the second result from the first."""
        results = [child(x) for child in self.children]
        minuend, subtrahend = results[0], results[1]
        return minuend - subtrahend

class Sin(Node):
    """Elementwise sine of a single sub-expression."""

    def __init__(self, child0):
        super().__init__()
        self.children = [child0]
        self.name = "Sin"

    def forward(self, x):
        """Evaluate every child on ``x`` and return ``torch.sin`` of the first result."""
        argument = [child(x) for child in self.children][0]
        return torch.sin(argument)
    
class Ln(Node):
    """Elementwise natural logarithm of a single sub-expression."""

    def __init__(self, child0):
        super().__init__()
        self.children = [child0]
        self.name = "Ln"

    def forward(self, x):
        """Evaluate every child on ``x`` and return ``torch.log`` of the first result."""
        argument = [child(x) for child in self.children][0]
        return torch.log(argument)

class Multiply(Node):
    """Elementwise product of two sub-expressions: ``child0(x) * child1(x)``."""

    def __init__(self, child0, child1):
        super().__init__()
        self.children = [child0, child1]
        self.name = "Multiply"

    def forward(self, x):
        """Evaluate every child on ``x`` and multiply the first two results."""
        factors = [child(x) for child in self.children]
        left, right = factors[0], factors[1]
        return left * right

class Const(Node):
    """Leaf node holding a single learnable scalar.

    ``forward`` repeats the scalar once per row of ``x``, giving a tensor of
    shape ``(x.size(0), 1)``.
    """

    def __init__(self, value):
        super().__init__()
        data = torch.tensor([value])
        # Integer tensors cannot require gradients, so Const(2) used to raise;
        # promote non-floating data to the default floating dtype.
        if not data.is_floating_point():
            data = data.to(torch.get_default_dtype())
        # nn.Parameter sets requires_grad=True on its own.
        self.value = nn.Parameter(data)
        self.name = "Const({0:.3f})".format(self.value.item())

    def forward(self, x):
        """Return the constant repeated once per row of ``x``."""
        # Only x's leading dimension is used; its values are ignored.
        return self.value.repeat(x.size(0), 1)

# Dict filled by the backward hook: id(module) -> (gradient, module name)
grads = {}

def savegrad(self, grad_input, grad_output):
    """Backward hook that records the adjoint of a node's output.

    ``grad_output[0]`` is the gradient of the final result with respect to
    this module's output, which is the per-node adjoint tabulated by Baydin
    et al. ``grad_input`` instead holds gradients with respect to the inputs
    of the module's last operation (for these nodes, the children's outputs),
    so recording ``grad_input[0]`` attributed a child's gradient to its parent.

    A node evaluated more than once (e.g. a shared constant) fires the hook
    once per use and each call overwrites the entry, so for such nodes only a
    single contribution is kept; the leaf parameter's ``.grad`` holds the sum.
    """
    if grad_output and grad_output[0] is not None:
        grads[id(self)] = (grad_output[0], self.name)

# Expression with shared constants c0 = 2, c1 = 5:
#   f = ln(c0) + c0 * c1 - sin(c1)
consts = [Const(2.), Const(5.)]
expression = Minus(Plus(Ln(consts[0]), Multiply(consts[0], consts[1])), Sin(consts[1]))
# get_subtree() lists a shared node once per use. Keep a single entry per
# module (dict.fromkeys preserves first-seen order) so every module gets
# exactly one hook and is printed once.
expression_nodes = list(dict.fromkeys(expression.get_subtree()))

# Dummy input: the constants only use its batch size, x.size(0).
x = torch.tensor([[42.]])

# NOTE: newer PyTorch releases deprecate register_backward_hook in favour of
# register_full_backward_hook.
for node in expression_nodes:
    node.register_backward_hook(savegrad)

z = expression(x)
z.backward()

# These values correspond with Baydin et al.
print(consts[0].value.grad)
print(consts[1].value.grad)

# Per-node gradients captured by the backward hooks
for node in expression_nodes:
    if id(node) in grads:
        print(grads[id(node)])

Hi Joe!

I haven’t looked at your code, but could retain_grad() work for your
use case?

A simple example:

>>> import torch
>>> torch.__version__
'1.11.0'
>>> v = torch.tensor ([1., 2., 3.], requires_grad = True)
>>> t = torch.outer (v, v)
>>> t.retain_grad()
>>> t.sum().backward()
>>> t.grad
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

Best.

K. Frank

Thank you! Both approaches work. The values differed from the paper because the most recent gradient overwrites `.grad`: some nodes (the shared constants) are used twice, so for these nodes a history of gradients has to be retained rather than a single value. With that history kept, the values now correspond with the paper.

Sample Node and hook

class Node(nn.Module):
    """Base expression-tree node that can optionally retain output gradients.

    Once ``retain_grad()`` has switched the flag on, subclasses call
    ``retain_grad()`` on each forward output and append it to ``self.output``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): shadows nn.Module.children(); see child registration.
        self.children = []
        self.name = 0  # placeholder label; subclasses assign a string
        self.ret_grad = False  # toggled through retain_grad()
        # Outputs recorded by subclasses. Nothing here clears this list, so it
        # grows with every forward pass while retention is enabled
        # (NOTE(review): consider clearing it between iterations).
        self.output = []

    def retain_grad(self, retain_grad=True):
        """Enable (or, with ``False``, disable) retention of output gradients."""
        self.ret_grad = retain_grad

    def forward(self, x):
        """The base node computes nothing and returns None."""
        return None

    def get_subtree(self):
        """Return this node followed by all descendants in depth-first pre-order.

        A node reachable along several paths appears once per path.
        """
        collected = []
        self._get_subtree_recursive(collected)
        return collected

    def _get_subtree_recursive(self, subtree):
        """Append this node, then every child's subtree, to ``subtree``."""
        subtree.append(self)
        for child in self.children:
            child._get_subtree_recursive(subtree)

class Plus(Node):
    """Sum node that can keep its outputs so their ``.grad`` can be read later."""

    def __init__(self, child0, child1):
        super().__init__()
        self.children = [child0, child1]
        self.name = "Plus"

    def forward(self, x):
        """Add the first two child results; optionally retain the sum's gradient."""
        operands = [child(x) for child in self.children]
        total = operands[0] + operands[1]
        if not self.ret_grad:
            return total
        # Keep every output (one per forward call) so a node evaluated several
        # times keeps a history of gradients instead of a single one.
        total.retain_grad()
        self.output.append(total)
        return total

def savegrad(self, grad_input, grad_output):
    """Backward hook that appends the adjoint of a node's output to ``grads``.

    ``grad_output[0]`` is the gradient of the final result with respect to
    this module's output, which is the node's own adjoint. ``grad_input``
    holds gradients with respect to the inputs of the module's last operation
    (the children's outputs), so it belongs to the children, not this node.

    One entry is appended per hook call, so a node used several times keeps a
    history of contributions. ``grads`` must be an existing module-level dict.
    """
    if grad_output and grad_output[0] is not None:
        grads.setdefault(id(self), []).append(grad_output[0])