Backward hook with specified gradient

I want to set up a backward hook that modifies the gradient while also calling backward() with a specified gradient.

First, I set up a forward hook on a ReLU module of ResNet34 to get an intermediate output.

def mid_out_hook(self):
    def hook(module, inp, out):
        # store the module output so it can be read after the forward pass
        self.mid_out = out
    return hook

def get_mid_out(self, inp, module):
    # temporarily attach the forward hook, run a forward pass, then remove it
    handle = module.register_forward_hook(self.mid_out_hook())
    self.model(inp)
    handle.remove()
    return self.mid_out

mid = self.get_mid_out(inp, module)

I omit the wrapper class to save space. Then, for the same module, I register a full backward hook to modify the gradient and call backward() on the intermediate output with a specified gradient:

mask = torch.zeros_like(mid)

def grad_hook(mask):
    def hook(module, grad_inp, grad_out):
        # replace the gradient w.r.t. the module input with the masked output gradient
        new_grad = torch.mul(grad_out[0], mask)
        return (new_grad,)
    return hook

module.register_full_backward_hook(grad_hook(mask))
mid.backward(gradient=torch.ones_like(mid))

However, I got the following error:

RuntimeError: Module backward hook for grad_input is called before the grad_output one. This happens because the gradient in your nn.Module flows to the Module’s input without passing through the Module’s output. Make sure that the output depends on the input and that the loss is computed based on the output.

A possible workaround is to modify the gradient myself and backpropagate from the module input instead of the output, avoiding the backward hook altogether. However, I am wondering whether there is another way to resolve the conflict between modifying the gradient in a backward hook and calling backward() with a specified gradient.
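For concreteness, a rough sketch of what I mean (mid_in is a name I am introducing for the module input, captured with a second forward hook using the same pattern as get_mid_out above):

def mid_in_hook(self):
    def hook(module, inp, out):
        # the ReLU takes a single input tensor
        self.mid_in = inp[0]
    return hook

# registration/removal of this hook omitted, same pattern as in get_mid_out
specified = torch.ones_like(mid)
# apply the mask to the specified gradient by hand and start the backward
# pass from the module input, so the module's backward hook is never involved
self.mid_in.backward(gradient=specified * mask)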

Since your module only takes a single input, you could just register a hook on the input tensor of the module instead. torch.Tensor.register_hook — PyTorch 2.0 documentation

Note that Tensor hooks would behave differently in a subtle way. You can read Autograd mechanics — PyTorch 2.0 documentation for more details on that.
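A minimal sketch of that, assuming mid_in is the tensor fed into your module (the name is just a placeholder; you could capture it with another forward hook) and reusing the mask from your code:

# fires whenever the gradient with respect to mid_in is computed;
# the returned tensor replaces the gradient that keeps flowing backward
handle = mid_in.register_hook(lambda grad: grad * mask)

mid.backward(gradient=torch.ones_like(mid))
handle.remove()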

It could work, but I guess I would need to register the hook for every different input, which is a little annoying.

I realize that the problem is that the intermediate output is obtained via the forward hook instead of from the module output directly. Because of that, the backward hook is not successfully applied to the module output, since the captured tensor does not record the corresponding grad_fn. However, I still don’t know how to fix it at this point.

Another way is to pass the module inputs through a dummy module that returns all of its inputs as-is, and then register a module full backward hook on that module.
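A rough sketch of that idea; PassThrough and the model.relu replacement below are illustrative assumptions about where the target ReLU sits, and grad_hook and mask are reused from the question:

import torch
import torch.nn as nn

class PassThrough(nn.Module):
    # returns its input unchanged; its only job is to be a module whose
    # full backward hook sees the gradient flowing back toward this point
    def forward(self, x):
        return x

dummy = PassThrough()
# illustrative wiring: route the target ReLU's input through the dummy
model.relu = nn.Sequential(dummy, model.relu)
dummy.register_full_backward_hook(grad_hook(mask))

Do the rewiring and register the hook before running the forward pass that captures the intermediate output, since the bookkeeping for full backward hooks is set up during the forward pass.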