It seems that the callbacks registered with register_backward_hook only receive gradients w.r.t. the module's output and the inputs to its last internal operation, as detailed here.
What one would expect during the backward pass is the gradient with respect to the module's parameters, e.g. weight and bias, but no such thing is passed to the hook.
The gradients are stored in the grad property of the module parameters after loss.backward() has been called, but I'd like to collect them via hooks, as I do for all other information.
How can this be done?
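One possible workaround (a sketch, not necessarily the canonical answer): instead of a module-level backward hook, register a tensor-level hook on each parameter with Tensor.register_hook, which fires with the gradient w.r.t. that tensor during the backward pass. The names grads and make_hook below are illustrative.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 2)
grads = {}  # collected parameter gradients, keyed by parameter name

def make_hook(name):
    # Tensor.register_hook calls this with the gradient w.r.t. the tensor.
    def hook(grad):
        grads[name] = grad.detach().clone()
    return hook

for name, param in model.named_parameters():
    param.register_hook(make_hook(name))

loss = model(torch.randn(4, 3)).sum()
loss.backward()

# grads now contains entries for 'weight' and 'bias',
# matching model.weight.grad and model.bias.grad.
```

This captures the gradients at the moment they are computed, before any optimizer step or grad zeroing, which is exactly what a backward hook would be expected to provide.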