Hook function with more parameters

Is there a way to use register_backward_hook with a hook whose signature is

hook(module, grad_input, grad_output, parameter1, parameter2)

instead of

hook(module, grad_input, grad_output)

?

You don’t need the extra parameters in the signature. You can use up-values (variables captured from the enclosing scope, i.e. a closure) to store them for you. Something like

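parameter1 = ...
def hook(module, grad_input, grad_output):
    # parameter1 is read from the enclosing scope; use it here as needed.
    # Return None to keep grad_input unchanged, or a replacement for grad_input.
    return None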
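To make the closure idea concrete, here is a minimal sketch with a small factory function that builds the hook. The names make_hook, name and log are made up for illustration, and it assumes a PyTorch version that has register_full_backward_hook (the newer replacement for register_backward_hook); the same pattern should work with either.

```python
import torch
import torch.nn as nn

def make_hook(name, log):
    # `name` and `log` are hypothetical extra parameters; the inner
    # function captures them from the enclosing scope.
    def hook(module, grad_input, grad_output):
        # Record which layer fired and the shape of the output gradient.
        log.append((name, tuple(grad_output[0].shape)))
        # Returning None leaves the gradients unchanged.
        return None
    return hook

layer = nn.Linear(4, 2)
log = []
handle = layer.register_full_backward_hook(make_hook("fc", log))

x = torch.randn(3, 4, requires_grad=True)
layer(x).sum().backward()
handle.remove()  # detach the hook once it is no longer needed
```

A factory like this also lets you register the same hook logic on several modules, each with its own captured values.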

You can make this even cleaner by slightly modifying the code from the thread "Why cant I see .grad of an intermediate variable?"

You can also use functools.partial to bind the extra arguments:

from functools import partial
module.register_backward_hook(partial(hook, parameter1=p1, parameter2=p2))
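As a rough sketch of why this works: partial returns a new callable with the keyword arguments already filled in, so whoever calls it only has to pass the three positional arguments. The snippet below calls the bound hook by hand with placeholder values to stand in for what the module would pass during backward; the seen list and the placeholder arguments are assumptions for illustration only.

```python
from functools import partial

seen = []  # hypothetical container, just to observe the bound values

def hook(module, grad_input, grad_output, parameter1, parameter2):
    # The two extra parameters arrive via functools.partial.
    seen.append((parameter1, parameter2))
    # Returning None leaves the gradients unchanged in a real backward hook.
    return None

# Bind the extra parameters by keyword; the result takes three arguments.
bound_hook = partial(hook, parameter1="p1", parameter2="p2")

# Stand-in for the call made during backward:
# hook(module, grad_input, grad_output)
bound_hook(None, (), ())
```

Binding by keyword matters here: the extra parameters come after the three positional ones, so passing them positionally to partial would put them in the wrong slots.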