I want to set up a backward hook that modifies the gradient while backpropagating from an intermediate output with a specified (explicit) gradient.
First, I set up a forward hook on a ReLU module of ResNet34 to capture the intermediate output.
def mid_out_hook(self):
    def hook(module, inp, out):
        # save the module's output so it can be read after the forward pass
        self.mid_out = out
    return hook

def get_mid_out(self, inp, module):
    # temporarily attach the forward hook, run the model, then detach it
    hook = module.register_forward_hook(self.mid_out_hook())
    self.model(inp)
    hook.remove()
    return self.mid_out

mid = wrapper.get_mid_out(inp, module)
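For reference, wrapper, inp, and module are set up roughly as below; the class name FeatureWrapper, the choice of the stem ReLU, and the dummy input shape are only placeholders, not my exact code:

import torch
import torchvision

class FeatureWrapper:
    def __init__(self, model):
        self.model = model
        self.mid_out = None

    # mid_out_hook and get_mid_out are the two methods shown above

model = torchvision.models.resnet34()
wrapper = FeatureWrapper(model)
module = model.relu                    # the stem ReLU, applied once per forward pass
inp = torch.randn(1, 3, 224, 224)      # dummy input batch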
I omit my actual wrapper class to save space; the sketch above only shows how the pieces fit together. Then, for the same module, I register a full backward hook to modify the gradient, and call backward on the intermediate output with a specified gradient:
mask = torch.zeros_like(mid)

def grad_hook(mask):
    def hook(module, grad_inp, grad_out):
        # scale the gradient arriving at the module's output by the mask;
        # the returned tuple replaces grad_inp for the rest of the backward pass
        new_grad = torch.mul(grad_out[0], mask)
        return (new_grad,)
    return hook

module.register_full_backward_hook(grad_hook(mask))
mid.backward(gradient=torch.ones_like(mid))
However, I got the following error:
RuntimeError: Module backward hook for grad_input is called before the grad_output one. This happens because the gradient in your nn.Module flows to the Module’s input without passing through the Module’s output. Make sure that the output depends on the input and that the loss is computed based on the output.
A possible workaround is to modify the gradient myself and then call backward on the module's input instead of its output, avoiding the backward hook entirely. However, I am wondering whether there is another way to resolve the conflict between modifying the gradient in a backward hook and backpropagating with a specified gradient.
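Concretely, the workaround I have in mind looks roughly like this (mid_in and feats are just illustrative names, and it assumes the hooked ReLU is not in-place; if it is, as in torchvision's ResNet, the module would need to be replaced with a non-inplace ReLU first):

feats = {}

def capture(module, inp, out):
    # keep both the input and the output of the hooked ReLU
    feats['inp'] = inp[0]
    feats['out'] = out

h = module.register_forward_hook(capture)
wrapper.model(inp)
h.remove()

mid_in, mid_out = feats['inp'], feats['out']
mask = torch.zeros_like(mid_out)                      # same placeholder mask as before

grad_out = torch.ones_like(mid_out) * mask            # specified gradient, already masked
grad_in = grad_out * (mid_in > 0).to(grad_out.dtype)  # ReLU derivative applied by hand
mid_in.backward(gradient=grad_in)                     # backward from the input, no backward hook needed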