Modify module gradients during backward pass

I want to modify the gradients of a module during the backward pass, for example by clamping them to a range. I don't want to just clamp the gradients of the leaves afterwards, but actually modify the gradients as they flow through the backward pass.

I have a custom Node class that allows me to build trees that I need for another task.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Node(nn.Module):
    def __init__(self):
        super(Node, self).__init__()
        # Children are kept in a plain Python list, so they are not registered
        # as submodules (this also shadows nn.Module.children()).
        self.children = []
    def forward(self, x):
        pass
    def get_subtree(self):
        subtree = []
        self._get_subtree_recursive(subtree)
        return subtree
    def _get_subtree_recursive(self, subtree):
        subtree.append(self)
        for c in self.children:
            c._get_subtree_recursive(subtree)

class Pow(Node):
    def __init__(self, child0, child1):
        super(Pow, self).__init__()
        self.children = [child0, child1]
        self.name = "Pow"
    def forward(self, x):
        c_outputs = [c(x) for c in self.children]
        output = c_outputs[0] ** c_outputs[1]
        return output

class Var(Node):
    def __init__(self, value):
        super(Var, self).__init__()
        self.value = torch.tensor([value], requires_grad=True)
        self.name = "Var"
    def forward(self, x):
        output = self.value.repeat(x.size(0),1)
        return output

class Feature(Node):
    def __init__(self, id):
        super(Feature, self).__init__()
        self.id = id
        self.name = "Feature"
    def forward(self, x):
        output = x[:, self.id].view(-1,1)
        return output
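
For example (a small sketch of how these nodes compose), Pow(Feature(0), Var(2.)) represents x**2 and evaluates it row by row over a batch:

expr = Pow(Feature(0), Var(2.))
print(expr(torch.tensor([[2.], [3.]])))  # -> [[4.], [9.]]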

I've tried the following, but when I return something other than (None,) from the hook, I get an error.

# Hook should clamp the gradient
def change_grad(module, grad_input, grad_output):
    print('grad_input {}'.format(grad_input))
    print('grad_output {} what it should be: {}'.format(
        grad_output, torch.clamp(input=grad_output[0], min=-1.0, max=1.0)))

    # 2 lines added so no error is thrown
    if grad_input[0] is None:
        return (None,)

    return (torch.clamp(input=grad_output[0], min=-1.0, max=1.0),)
hook_handles = []

c = [Var(3.), Var(2.)]
expression = Pow(Pow(c[0],Feature(0)),c[1])
expression_nodes = expression.get_subtree()

grads = {}

for idx, node in enumerate(expression_nodes):
    if not isinstance(node, Feature):
        grads[id(node)] = []
        handle = node.register_full_backward_hook(change_grad)
        hook_handles.append(handle)
        
x = torch.tensor([[3.]])
z = expression(x)
z.backward()

print(c[0].value.grad,c[1].value.grad)

for handle in hook_handles:
    handle.remove()
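
For reference, the tuple returned by a full backward hook is used in place of grad_input, i.e. the gradient with respect to the module's positional forward inputs (here just x); it cannot replace grad_output. Since x does not require grad in this example, grad_input is (None,) at every node, which is why returning a clamped grad_output raises an error while returning (None,) does not. A hook whose return value at least has the right structure would look roughly like this (a sketch; in this particular example it is effectively a no-op, because there is no input gradient to clamp):

def clamp_grad_input(module, grad_input, grad_output):
    # A full backward hook may only replace grad_input; entries that are None
    # (inputs that do not require grad) have to stay None.
    return tuple(
        None if g is None else torch.clamp(g, min=-1.0, max=1.0)
        for g in grad_input
    )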

There's an easier way to do it: register hooks on the output tensors during the forward pass:

def modify_grad(grad):
    # A tensor hook receives the gradient w.r.t. the tensor it is registered on
    # and may return a modified gradient that is used in its place.
    modified_grad = torch.clamp(grad, min=-1.0, max=1.0)
    return modified_grad

class Pow(Node):
    def __init__(self, child0, child1):
        super(Pow, self).__init__()
        self.children = [child0, child1]
        self.name = "Pow"
    def forward(self, x):
        c_outputs = [c(x) for c in self.children]
        output = c_outputs[0] ** c_outputs[1]
        # Clamp the gradient that flows back through this node's output.
        output.register_hook(modify_grad)
        return output
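
With this version of Pow, the earlier example can be run unchanged (a minimal sketch reusing the Var and Feature classes from above); the hook is registered on a fresh output tensor on every forward pass, so there are no module hook handles to manage:

c = [Var(3.), Var(2.)]
expression = Pow(Pow(c[0], Feature(0)), c[1])

x = torch.tensor([[3.]])
z = expression(x)
z.backward()

# The gradient arriving at each Pow output is clamped to [-1, 1] before it is
# propagated down to that node's children.
print(c[0].value.grad, c[1].value.grad)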