I am trying to insert a backward pre hook into an nn.Linear layer:

class Insert_Hook():
    def __init__(self, module, new_grad_output):
        self.new_grad_output = new_grad_output
        # use prepend=True so that this is definitely the first hook being applied
        self.hook = module.register_full_backward_pre_hook(self.hook_fn)

    def hook_fn(self, module, grad_output):
        # simply return the previously caught grad_output
        # this will replace the current grad_output
        return self.new_grad_output

    def close(self):
        self.hook.remove()

artificial_grad = (torch.ones([3, 2]),)
print(artificial_grad)
insert_hook = Insert_Hook(affected_layer, artificial_grad)
Here, affected_layer is an nn.Linear layer of the correct size.
It gives me the error

'Linear' object has no attribute 'register_full_backward_pre_hook'

This surprises me, as I thought nn.Linear subclasses nn.Module, and the latter should definitely have this method as described here: Reference.
EDIT:
It will also not let me use the prepend keyword argument:
register_full_backward_hook() got an unexpected keyword argument 'prepend'
You are linking to the docs of the current 2.0.0 release, while you are using 1.13.1, which doesn't seem to provide this method, as seen here.
Update PyTorch and you should be able to use register_full_backward_pre_hook.
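Once you are on 2.0+, the prepend keyword from the question should also work. A minimal sketch of the intended usage, assuming affected_layer is the nn.Linear from the question (the layer and batch sizes below are placeholders chosen to match the [3, 2] artificial gradient):

import torch
import torch.nn as nn

affected_layer = nn.Linear(2, 2)
artificial_grad = (torch.ones(3, 2),)  # one entry per output, matching its shape

def replace_grad_output(module, grad_output):
    # the tuple returned here replaces grad_output before the module's backward runs
    return artificial_grad

# prepend=True puts this hook ahead of any other backward pre hooks on the module
handle = affected_layer.register_full_backward_pre_hook(replace_grad_output, prepend=True)

out = affected_layer(torch.randn(3, 2))
out.sum().backward()
handle.remove()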
I think the warning is related to the newly introduced torch.compile backend, not eager mode, as it's working for me in 2.0.0:

import torch
import torch.nn as nn

model = nn.Linear(10, 10)

def get_hook(name):
    def hook(module, grad_output):
        print("module {} called with grad_output.abs().sum() {}".format(
            name, grad_output[0].abs().sum()))
    return hook

model.register_full_backward_pre_hook(get_hook("linear"))

x = torch.randn(1, 10)
out = model(x)
out.mean().backward()
# module linear called with grad_output.abs().sum() 1.0000001192092896
I want to confirm that the function registered with register_full_backward_pre_hook is called before the backward computation of the module, and that we can modify grad_output before grad_input is calculated. Right?
Also, why can't I find register_full_backward_pre_hook in the PyTorch v2.0.0 docs?
Thanks.
But this is the description of register_module_full_backward_hook, which runs after the backward pass.
I am looking for the description of register_full_backward_pre_hook, which runs before the backward pass.
Here is an example that modifies grad_output successfully:
import torch
import torch.nn as nn

model = nn.Linear(10, 10)

def get_hook(name):
    def hook(module, grad_output):
        print("Before modify: module {} called with grad_output.abs().sum() {}".format(
            name, grad_output[0].abs().sum()))
        new_grad_output = grad_output[0].clone() * 10
        return tuple([new_grad_output])
    return hook

def check_hook(name):
    def hook(module, grad_input, grad_output):
        print("After modify: module {} called with grad_output.abs().sum() {}".format(
            name, grad_output[0].abs().sum()))
    return hook

model.register_full_backward_pre_hook(get_hook("linear"))
model.register_full_backward_hook(check_hook("linear"))

x = torch.randn(1, 10)
out = model(x)
out.mean().backward()

# Output:
# Before modify: module linear called with grad_output.abs().sum() 1.0000001192092896
# After modify: module linear called with grad_output.abs().sum() 10.0
Edit: for anyone reading in the future, module backward pre hooks were introduced in the 2.0.0 release but are buggy in that version, so if you want to use them you should be on 2.0.1.
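A small guard sketch if you want to fail fast on the buggy build (assumes the packaging module is available for version parsing):

import torch
from packaging import version

# Assumption: per the note above, anything below 2.0.1 is treated as too buggy
# for module backward pre hooks
if version.parse(torch.__version__) < version.parse("2.0.1"):
    raise RuntimeError(
        "module backward pre hooks need PyTorch >= 2.0.1, found " + torch.__version__)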
Another workaround is to apply a hook to the output of your module instead.
import torch
import torch.nn as nn

a = torch.ones(2, requires_grad=True)
model = nn.Linear(2, 2)

def fn(grad_output):
    return grad_output * 2

def fn2(module, grad_inputs, grad_output):
    # The modification is still observed
    print(grad_inputs, grad_output)
    return (grad_inputs[0] / 2,)

# No longer using full backward pre hooks
# model.register_full_backward_pre_hook(fn)
model.register_full_backward_hook(fn2)

out = model(a)
# Instead, register a hook to the output of your module
# Unlike module hooks, register AFTER the forward runs
out.register_hook(fn)

out.sum().backward()
print(a.grad)
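Applied to the original question, the same workaround can inject a fixed, artificial gradient without any module hooks. A sketch, with affected_layer and the shapes as placeholders taken from the question:

import torch
import torch.nn as nn

affected_layer = nn.Linear(2, 2)
artificial_grad = torch.ones(3, 2)  # must match the output shape

x = torch.randn(3, 2)
out = affected_layer(x)
# a tensor hook receives the gradient directly; returning a tensor replaces it
handle = out.register_hook(lambda grad: artificial_grad)
out.sum().backward()
handle.remove()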