I’m trying to save the intermediate gradients of a computation graph, as done in Table 3 of Baydin et al. I have custom modules whose gradients I’d like to inspect, so I registered a backward hook on each of them. The gradients of the leaf nodes are correct, but the gradients I get from the hooks aren’t. How do I get the gradients that flow through the operators?
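For context, this is the hook pattern I’m using, tried on a plain nn.Linear first (a minimal sketch just to show the signature; print_grads is my own name):

import torch
import torch.nn as nn

lin = nn.Linear(2, 1)

def print_grads(module, grad_input, grad_output):
    # grad_output[0] is the gradient w.r.t. the module's output;
    # grad_input[0] should (as far as I understand) be the gradient flowing into the module
    print(grad_output[0])
    print(grad_input[0])

lin.register_backward_hook(print_grads)
lin(torch.randn(3, 2)).sum().backward()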
Here’s what I have now:
import torch
import torch.nn as nn


class Node(nn.Module):
    def __init__(self):
        super(Node, self).__init__()
        self.children = []
        self.name = 0

    def forward(self, x):
        pass

    def get_subtree(self):
        subtree = []
        self._get_subtree_recursive(subtree)
        return subtree

    def _get_subtree_recursive(self, subtree):
        subtree.append(self)
        for c in self.children:
            c._get_subtree_recursive(subtree)


class Plus(Node):
    def __init__(self, child0, child1):
        super(Plus, self).__init__()
        self.children = [child0, child1]
        self.name = "Plus"

    def forward(self, x):
        c_outputs = [c(x) for c in self.children]
        return c_outputs[0] + c_outputs[1]


class Minus(Node):
    def __init__(self, child0, child1):
        super(Minus, self).__init__()
        self.children = [child0, child1]
        self.name = "Minus"

    def forward(self, x):
        c_outputs = [c(x) for c in self.children]
        return c_outputs[0] - c_outputs[1]


class Sin(Node):
    def __init__(self, child0):
        super(Sin, self).__init__()
        self.children = [child0]
        self.name = "Sin"

    def forward(self, x):
        c_outputs = [c(x) for c in self.children]
        return torch.sin(c_outputs[0])


class Ln(Node):
    def __init__(self, child0):
        super(Ln, self).__init__()
        self.children = [child0]
        self.name = "Ln"

    def forward(self, x):
        c_outputs = [c(x) for c in self.children]
        return torch.log(c_outputs[0])


class Multiply(Node):
    def __init__(self, child0, child1):
        super(Multiply, self).__init__()
        self.children = [child0, child1]
        self.name = "Multiply"

    def forward(self, x):
        c_outputs = [c(x) for c in self.children]
        return c_outputs[0] * c_outputs[1]


class Const(Node):
    def __init__(self, value):
        super(Const, self).__init__()
        self.value = nn.Parameter(torch.tensor([value], requires_grad=True))
        self.name = "Const({0:.3f})".format(self.value.item())

    def forward(self, x):
        return self.value.repeat(x.size(0), 1)
# Dict to save grads in hook
grads = {}


def savegrad(self, grad_input, grad_output):
    if grad_input[0] is not None:
        grads[id(self)] = (grad_input[0], self.name)


# Expression with common constants: ln(c0) + c0*c1 - sin(c1)
consts = [Const(2.), Const(5.)]
expression = Minus(Plus(Ln(consts[0]), Multiply(consts[0], consts[1])), Sin(consts[1]))
expression_nodes = expression.get_subtree()

# Random input
x = torch.tensor([[42.]])

for idx, c in enumerate(expression_nodes):
    c.register_backward_hook(savegrad)

z = expression(x)
z.backward()

# These values correspond with Baydin et al.
print(consts[0].value.grad)
print(consts[1].value.grad)

# These don't
for idx, c in enumerate(expression_nodes):
    if id(c) in grads.keys():
        print(grads[id(c)])
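To be clear about which values I mean: I can read the adjoint of every node’s output by retaining its grad via a forward hook (sketch below; save_output and outputs are my own names), but I’d have expected the backward hooks above to give me the same numbers directly.

# Sketch: keep each node's output tensor, retain its grad, and read it after backward()
outputs = {}

def save_output(module, inputs, output):
    output.retain_grad()
    outputs[id(module)] = output

for c in expression_nodes:
    c.register_forward_hook(save_output)

z = expression(x)
z.backward()

for c in expression_nodes:
    if id(c) in outputs:
        print(c.name, outputs[id(c)].grad)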