I want to modify the gradients of a module during the backward pass, for example by clamping them to a range. I don’t want to just clamp the .grad of the leaves after backward has finished; I want to clamp the gradients while they are flowing through the backward pass.
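My understanding of the mechanism, from the docs on register_full_backward_hook, is that a hook of the form hook(module, grad_input, grad_output) can return a replacement for grad_input, as long as the returned tuple has one entry per input. A minimal sketch on a plain nn.Linear (a stand-in I use here just to illustrate, not my actual code):

import torch
import torch.nn as nn

lin = nn.Linear(2, 2)

def clamp_hook(module, grad_input, grad_output):
    # The returned tuple replaces grad_input, so it must have one entry
    # per forward input; None is kept where an input has no gradient.
    return tuple(None if g is None else torch.clamp(g, min=-1.0, max=1.0)
                 for g in grad_input)

handle = lin.register_full_backward_hook(clamp_hook)
out = lin(torch.randn(4, 2, requires_grad=True)).sum()
out.backward()
handle.remove()

This is the behaviour I am trying to reproduce inside my tree of modules below.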
I have a custom Node class that allows me to build trees that I need for another task.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Node(nn.Module):
    def __init__(self):
        super(Node, self).__init__()
        # Note: this attribute shadows the nn.Module.children() method;
        # the tree is walked manually via get_subtree() instead.
        self.children = []

    def forward(self, x):
        pass

    def get_subtree(self):
        # Returns this node plus all of its descendants, depth first.
        subtree = []
        self._get_subtree_recursive(subtree)
        return subtree

    def _get_subtree_recursive(self, subtree):
        subtree.append(self)
        for c in self.children:
            c._get_subtree_recursive(subtree)

class Pow(Node):
    # Evaluates child0(x) ** child1(x).
    def __init__(self, child0, child1):
        super(Pow, self).__init__()
        self.children = [child0, child1]
        self.name = "Pow"

    def forward(self, x):
        c_outputs = [c(x) for c in self.children]
        output = c_outputs[0] ** c_outputs[1]
        return output

class Var(Node):
    # A learnable constant, broadcast to the batch size.
    def __init__(self, value):
        super(Var, self).__init__()
        self.value = torch.tensor([value], requires_grad=True)
        self.name = "Var"

    def forward(self, x):
        output = self.value.repeat(x.size(0), 1)
        return output

class Feature(Node):
    # Selects input column `id` as a (batch, 1) tensor.
    def __init__(self, id):
        super(Feature, self).__init__()
        self.id = id
        self.name = "Feature"

    def forward(self, x):
        output = x[:, self.id].view(-1, 1)
        return output
I’ve tried the following. Whenever the hook returns anything other than (None,), I get an error.
# Hook that should clamp the gradient as it passes through the node
def change_grad(module, grad_input, grad_output):
    print('grad_input {}'.format(grad_input))
    print('grad_output {} what it should be: {}'.format(
        grad_output, torch.clamp(input=grad_output[0], min=-1.0, max=1.0)))
    # 2 lines added so no error is thrown
    if grad_input[0] is None:
        return (None,)
    return (torch.clamp(input=grad_output[0], min=-1.0, max=1.0),)

hook_handles = []
c = [Var(3.), Var(2.)]
expression = Pow(Pow(c[0], Feature(0)), c[1])  # (c0 ** x) ** c1
expression_nodes = expression.get_subtree()
grads = {}
for idx, node in enumerate(expression_nodes):
    if not isinstance(node, Feature):
        grads[id(node)] = []
        handle = node.register_full_backward_hook(change_grad)
        hook_handles.append(handle)

x = torch.tensor([[3.]])
z = expression(x)
z.backward()
print(c[0].value.grad, c[1].value.grad)

for handle in hook_handles:
    handle.remove()
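For comparison, this is the baseline I check against: the same expression written out directly with plain tensors, no Node modules and no hooks, which gives the unclamped gradients (the expected values in the comment are hand-computed from z = c0 ** (x * c1)):

import torch

# (c0 ** x) ** c1 evaluated at x = 3, with c0 = 3 and c1 = 2
c0 = torch.tensor([3.], requires_grad=True)
c1 = torch.tensor([2.], requires_grad=True)
x = torch.tensor([[3.]])
z = (c0.repeat(x.size(0), 1) ** x[:, 0].view(-1, 1)) ** c1.repeat(x.size(0), 1)
z.backward()
# dz/dc0 = x * c1 * c0 ** (x * c1 - 1) = 1458
# dz/dc1 = z * x * ln(c0) ≈ 2402.7
print(c0.grad, c1.grad)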