Understanding gradient calculation with backward_pre_hooks

Hello everyone.

A while ago I asked how to use full_backward_pre_hooks.

From my understanding, these can be used to manipulate the backward-pass computation of the layer they are attached to.

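We have three components: grad_output (the gradient arriving at the layer's output from the layers after it), grad (the layer's local gradient, i.e. the derivative of its output with respect to its input), and grad_input (the gradient the layer passes back to the layers before it).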

So the calculation is approximately as follows:

grad_input = grad_output*grad
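To make the formula concrete for nn.Linear: for y = x Wᵀ + b, the local "grad" is the weight matrix, so by the chain rule grad_input = grad_output @ W. The minimal sketch below checks this against autograd. The seed and the upstream gradient of 0.5 per element (which mimics out.mean() over two outputs) are illustrative choices, not taken from the original post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(2, 2)
x = torch.randn(1, 2, requires_grad=True)
out = layer(x)

# Upstream gradient as delivered by out.mean(): 1 / numel = 0.5 per element
grad_output = torch.full_like(out, 0.5)
out.backward(grad_output)

# Chain rule for y = x W^T + b: dL/dx = dL/dy @ W (the bias does not contribute)
manual_grad_input = grad_output @ layer.weight.detach()
print(torch.allclose(x.grad, manual_grad_input))  # True
```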

So, to replace the incoming grad_output of a layer and thereby modify the subsequent backward-pass computations (in particular grad_input), I expected full_backward_pre_hook to be the solution. But running the following code doesn't confirm my intuition:

import torch
import torch.nn as nn

class Backward_Debug_Hook():
  def __init__(self, module):
    self.hook = module.register_full_backward_hook(self.hook_fn)

  def hook_fn(self, module, grad_input, grad_output):
    print('grad_output')
    print(grad_output)
    print('grad_input')
    print(grad_input)
    
  def close(self):
    self.hook.remove() 

class Insert_Hook():
  def __init__(self, module, new_grad_output=None):
    self.new_grad_output = new_grad_output
    self.hook = module.register_full_backward_pre_hook(self.hook_fn)

  def hook_fn(self, module, grad_output):
    return self.new_grad_output

  def close(self):
    self.hook.remove()

# simple model
model = nn.Sequential(
  nn.Linear(2, 2),
  nn.Sigmoid(),
  nn.Linear(2,2)
)
last_layer = model[-1]

debug_hook = Backward_Debug_Hook(last_layer) # attach debug hook
x = torch.randn(1, 2) # artificial input
out = model(x) # forward pass
print('without gradient insertion')
out.mean().backward() # backward pass


model.zero_grad()

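artificial_grad = (100*torch.ones([1,2]),) # gradient to inject in place of the real grad_output
insert_hook = Insert_Hook(last_layer, artificial_grad) # attach pre-hook that returns artificial_grad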
out = model(x) # forward pass
print('with gradient insertion')
out.mean().backward() # backward pass

However, the output is:

without gradient insertion
grad_output
(tensor([[0.5000, 0.5000]]),)
grad_input
(tensor([[ 0.3774, -0.1604]]),)
with gradient insertion
grad_output
(tensor([[100., 100.]]),)
grad_input
(tensor([[ 0.3774, -0.1604]]),)

Based on the reasoning above, I expected grad_input to be different in the second run.
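Because grad_input of a Linear layer is linear in grad_output, replacing 0.5 with 100 should scale grad_input by 100 / 0.5 = 200. With the printed values, the second grad_input would be roughly 200 × [0.3774, -0.1604] ≈ [75.48, -32.08] (up to rounding of the printed numbers). A small sketch of that scaling, using a freshly seeded layer rather than the model from the post:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(2, 2)
W = layer.weight.detach()

g_orig = torch.full((1, 2), 0.5)   # what out.mean().backward() delivers for 2 outputs
g_new = torch.full((1, 2), 100.0)  # the gradient the pre-hook tries to insert

# grad_input = grad_output @ W for a Linear layer
gi_orig = g_orig @ W
gi_new = g_new @ W
print(torch.allclose(gi_new, 200 * gi_orig))  # True: 100 / 0.5 = 200
```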

Where lies the mistake in my understanding?

Bump for visibility.

Ah, this is probably a bug.

import torch
import torch.nn as nn

a = torch.ones(2, requires_grad=True)

model = nn.Linear(2, 2)

def fn(module, grad_output):
    return (grad_output[0] * 0,)

model.register_full_backward_pre_hook(fn, prepend=False)

out = model(a)
out.sum().backward()
print(a.grad) # should be all zeros, but it's not

It should be fixed in a nightly build within the next couple of days; until then, the following patch can be applied:

diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py
index 6d5a97d4288e4..2cc7efeb5c124 100644
--- a/torch/utils/hooks.py
+++ b/torch/utils/hooks.py
@@ -223,6 +223,10 @@ def hook(_, grad_output):
                             raise RuntimeError("Backward hook for Modules where no input requires "
                                                "gradient should always return None or None for all gradients.")
                     self.grad_outputs = None
+
+                if self.grad_outputs is not None:
+                    return tuple(self.grad_outputs[i] for i in self.output_tensors_index)
+
             grad_fn.register_hook(hook)
 
         is_tuple = True
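If patching torch/utils/hooks.py is not an option, one possible workaround (a different mechanism, not the fix above) is to attach a tensor hook to the module's output from a forward hook: a hook registered with Tensor.register_hook may return a tensor that replaces the incoming gradient. A minimal sketch reproducing the zero-gradient test above; make_grad_replacer is a hypothetical helper name:

```python
import torch
import torch.nn as nn

def make_grad_replacer(new_grad):
    # Hypothetical helper: returns a forward hook that attaches a tensor hook to the
    # module's output, so the gradient arriving at that output is replaced by new_grad.
    def forward_hook(module, inputs, output):
        output.register_hook(lambda grad: new_grad.expand_as(grad).clone())
    return forward_hook

a = torch.ones(2, requires_grad=True)
model = nn.Linear(2, 2)
handle = model.register_forward_hook(make_grad_replacer(torch.zeros(2)))

out = model(a)
out.sum().backward()
print(a.grad)  # tensor([0., 0.])
handle.remove()
```

Since the hook is attached during each forward pass, it only affects graphs built while the forward hook is registered.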

Thanks for your reply. Glad my understanding wasn’t wrong after all.

Hope this lands in a stable release soon.

Thanks also for the quick fix. I'll try it later and report back if it's still not working. Cheers