Is there a way to use register_backward_hook with a hook whose signature is
hook(module, grad_input, grad_output, parameter1, parameter2)
instead of
hook(module, grad_input, grad_output)?
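You don’t need extra arguments in the signature. A closure can hold them for you: the hook simply reads parameter1 from the enclosing scope. Something like
parameter1 = ...
def hook(module, grad_input, grad_output):
    # parameter1 is captured from the enclosing scope
    print(parameter1)
    # return None to keep grad_input unchanged, or a tuple to replace it
    return None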
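To pass several values this way, you can wrap the hook in a small factory so that each call produces a hook with its own captured values. The sketch below runs without PyTorch: make_hook and fake_backward are hypothetical names, and fake_backward is only a stand-in that calls the hook with the three arguments a module backward hook receives and uses its return value the way autograd does (None keeps grad_input, a tuple replaces it).

```python
def make_hook(scale, log):
    """Build a 3-argument hook; scale and log are captured by the closure."""
    def hook(module, grad_input, grad_output):
        # The caller never passes scale/log; they come from make_hook's scope.
        log.append((module, scale))
        # Real grad_input entries can be None, so leave those untouched.
        return tuple(None if g is None else g * scale for g in grad_input)
    return hook


def fake_backward(module, hook, grad_input, grad_output):
    # Hypothetical stand-in for autograd: call the hook with the standard
    # three arguments and use its result in place of grad_input if not None.
    result = hook(module, grad_input, grad_output)
    return grad_input if result is None else result


log = []
hook = make_hook(0.5, log)
new_grads = fake_backward("fc1", hook, (2.0, 4.0), (1.0,))
print(new_grads)  # (1.0, 2.0)
print(log)        # [('fc1', 0.5)]
```

With a real module you would pass make_hook(...) straight to the registration call. Recent PyTorch versions deprecate register_backward_hook in favor of register_full_backward_hook, which takes a hook with the same (module, grad_input, grad_output) signature.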
You can take this further by slightly adapting the code in the thread "Why cant I see .grad of an intermediate variable?".
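Alternatively, you can use functools.partial to bind the extra arguments of a five-argument hook like the one in the question:
from functools import partial
module.register_backward_hook(partial(hook, parameter1=p1, parameter2=p2))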
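Here is a minimal, PyTorch-free sketch of why the partial approach works: the resulting callable accepts just the three standard arguments, and partial fills in the rest. The values 0.5 and "fc1" and the calls list are illustrative assumptions, and the direct call to bound stands in for autograd invoking the hook.

```python
from functools import partial

calls = []

def hook(module, grad_input, grad_output, parameter1, parameter2):
    # module, grad_input, grad_output come from the caller (autograd);
    # parameter1 and parameter2 were pre-bound by functools.partial.
    calls.append((module, parameter1, parameter2))
    return None  # None leaves grad_input unchanged

bound = partial(hook, parameter1=0.5, parameter2="fc1")

# Called with only the three standard arguments, as a backward hook is.
result = bound("linear", (1.0, 2.0), (3.0,))
print(result)  # None
print(calls)   # [('linear', 0.5, 'fc1')]
```

The extra parameters are bound by keyword on purpose: positional arguments given to partial fill the leading slots (module, grad_input, ...), so only keywords can target parameters that come after the three standard ones.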