Hello everyone.
A while ago I asked how to use full_backward_pre_hooks.
From my understanding, these can be used to manipulate the calculation of the backward pass of the affected layer.
We have three components:
- grad_output (coming from the output of the model) (accessible via the full_backward_hook and the full_backward_pre_hook)
- grad_input (going towards the input of the model) (accessible via the full_backward_hook, but NOT the full_backward_pre_hook)
- grad (of the current layer, so the layer's gradient of the form dLayer_output/dLayer_input) (stored in the parameters' .grad)
So the calculation is approximately as follows:
grad_input = grad_output*grad
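As a sanity check on the formula above, here is a minimal sketch for a single nn.Linear layer. It assumes the layer computes out = x @ W.T + b, in which case the local Jacobian with respect to the input is the weight matrix W itself, so grad_input should equal grad_output @ W. Note that this Jacobian is not what a parameter's .grad holds: .grad stores the gradient of the loss with respect to that parameter. The names layer, x and manual are only illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(2, 2)
x = torch.randn(1, 2, requires_grad=True)
out = layer(x)

# Same upstream gradient that out.mean().backward() sends into a 1x2 output.
grad_output = torch.tensor([[0.5, 0.5]])

# Ask autograd for the gradient w.r.t. the layer input, given grad_output.
(grad_input,) = torch.autograd.grad(out, x, grad_outputs=grad_output)

# Manual chain rule for out = x @ W.T + b: d(loss)/dx = grad_output @ W.
manual = grad_output @ layer.weight.detach()

print(torch.allclose(grad_input, manual))
```

For a linear layer, then, grad_input scales linearly with grad_output, which is why a changed grad_output would be expected to show up as a changed grad_input.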
So to replace the incoming grad_output of a layer, and thereby modify the subsequent backward pass computations (in particular the grad_input), I expected the full_backward_pre_hook to be the solution. But running the following code doesn't confirm my intuition:
import torch
import torch.nn as nn


class Backward_Debug_Hook():
    def __init__(self, module):
        self.hook = module.register_full_backward_hook(self.hook_fn)

    def hook_fn(self, module, grad_input, grad_output):
        print('grad_output')
        print(grad_output)
        print('grad_input')
        print(grad_input)

    def close(self):
        self.hook.remove()


class Insert_Hook():
    def __init__(self, module, new_grad_output=None):
        self.new_grad_output = new_grad_output
        self.hook = module.register_full_backward_pre_hook(self.hook_fn)

    def hook_fn(self, module, grad_output):
        return self.new_grad_output

    def close(self):
        self.hook.remove()


# simple model
model = nn.Sequential(
    nn.Linear(2, 2),
    nn.Sigmoid(),
    nn.Linear(2, 2)
)

last_layer = model[-1]
debug_hook = Backward_Debug_Hook(last_layer)  # attach debug hook

x = torch.randn(1, 2)  # artificial input
out = model(x)  # forward pass
print('without gradient insertion')
out.mean().backward()  # backward pass
model.zero_grad()

artifical_grad = (100 * torch.ones([1, 2]),)
insert_hook = Insert_Hook(last_layer, artifical_grad)
out = model(x)  # forward pass
print('with gradient insertion')
out.mean().backward()  # backward pass
The output is:
without gradient insertion
grad_output
(tensor([[0.5000, 0.5000]]),)
grad_input
(tensor([[ 0.3774, -0.1604]]),)
with gradient insertion
grad_output
(tensor([[100., 100.]]),)
grad_input
(tensor([[ 0.3774, -0.1604]]),)
According to my understanding from above, I expected the grad_input to be different the second time.
Where lies the mistake in my understanding?
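To make the expectation measurable rather than reading it off printed tuples, one option is to compare the gradient that reaches the first layer with and without the pre-hook. This is only a sketch under the assumption that the replaced grad_output is what autograd propagates: since out.mean() sends 0.5 per element into the last layer and the hook substitutes 100, every upstream gradient would then be scaled by 100 / 0.5 = 200. A ratio near 200 would indicate the replacement reached the upstream layers, while a ratio near 1 would indicate it did not. The helper first_layer_grad is a name made up for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(2, 2), nn.Sigmoid(), nn.Linear(2, 2))
x = torch.randn(1, 2)


def first_layer_grad():
    # Hypothetical helper: run forward + backward and return a copy of
    # the gradient that arrives at the first layer's weight.
    model.zero_grad()
    model(x).mean().backward()
    return model[0].weight.grad.clone()


g_plain = first_layer_grad()

# Replace the last layer's incoming grad_output with 100s, as in the post.
new_grad_output = (100 * torch.ones(1, 2),)
handle = model[-1].register_full_backward_pre_hook(
    lambda module, grad_output: new_grad_output
)
g_hooked = first_layer_grad()
handle.remove()

# Upstream gradients are linear in grad_output, so the norm ratio shows
# whether the replacement was actually propagated.
ratio = (g_hooked.norm() / g_plain.norm()).item()
print(f"ratio = {ratio:.3f}")
```

Checking an upstream parameter rather than the hooked layer's own printed grad_input also separates two questions: what the debug hook is shown, and what autograd actually uses.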