So I’m trying to implement a nice version of KFAC in PyTorch that requires only a single call to a function which augments an
nn.Module. I think the hooks mechanism could actually be quite an elegant solution for this; however, I need some clarification to make sure I understand things correctly.
For example, let’s consider
nn.Linear. During the backward pass, the backward hook’s
grad_input is a tuple of 3 variables. It is unclear to me why there are 3, and I did not find any docs on what they are (one seems to be the grad with respect to the weights, one the grad with respect to the inputs, and the third most likely the grad with respect to the broadcasted bias):
weight : torch.Size([784, 128])
bias : torch.Size()
d1: torch.Size([64, 784]) # most likely grad of broadcasted bias
d2: torch.Size([64, 128]) # most likely grad of inputs
d3: torch.Size([128, 784]) # most likely grad of weight.T
grad_output: torch.Size([64, 784])
The input to the module is not part of this; is there a way to access it? Generally, modifying gradients would require knowing the input.
Additionally, is there a way to make the modules produce twice as many gradients as the number of their inputs, but somehow intercept them and do separate things with half of them, while the other half remain normal gradients?
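Regarding accessing the input: a forward pre-hook can capture it before the backward pass runs. A minimal sketch (the `captured` dict and `save_input` name are mine, and the layer sizes just mirror the shapes above):

```python
import torch
import torch.nn as nn

# Capture each Linear module's input with a forward pre-hook, since the
# backward hook alone does not hand you the activations.
captured = {}

def save_input(module, inputs):
    # `inputs` is the tuple of positional args passed to forward()
    captured[module] = inputs[0].detach()

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
for m in model.modules():
    if isinstance(m, nn.Linear):
        m.register_forward_pre_hook(save_input)

out = model(torch.randn(64, 784))
# captured now maps each Linear module to its batch of input activations,
# which is exactly what KFAC needs for the activation covariance factor.
```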
The main clarification about module backward hooks is: “don’t use them”.
Sorry, but there isn’t a better explanation.
If you look at Yaroslavvb’s implementation, you’ll see that he wrapped the actual autograd functions in order to have access to everything he needs; quite likely, something like that is necessary.
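To illustrate what wrapping the autograd function looks like, here is a rough sketch (the class name `KFACLinearFn` is mine, not from that implementation): writing out the linear layer’s backward explicitly gives you a single place to intercept both the input and the output gradient.

```python
import torch

class KFACLinearFn(torch.autograd.Function):
    """Linear op with a hand-written backward, so gradients can be
    intercepted or modified (e.g. to record KFAC covariance statistics)."""

    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight)
        return input @ weight.t() + bias

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # This is where KFAC would record input / grad_output covariances.
        grad_input = grad_output @ weight        # [batch, in_features]
        grad_weight = grad_output.t() @ input    # [out_features, in_features]
        grad_bias = grad_output.sum(dim=0)       # [out_features]
        return grad_input, grad_weight, grad_bias

x = torch.randn(64, 784, requires_grad=True)
w = torch.randn(128, 784, requires_grad=True)
b = torch.zeros(128, requires_grad=True)
KFACLinearFn.apply(x, w, b).sum().backward()
```

The downside, as noted, is that existing code has to be rewritten to call `KFACLinearFn.apply` instead of the stock `nn.Linear`.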
Ah, that would be quite unfortunate then, since it means there is no easy way to apply this to already existing code: it requires that you explicitly use the overridden functions.
Nevertheless, thanks; at least that makes it clear what needs to be done.
Just a quick question. So, the problem with using
register_backward_hook is that it doesn’t always return the correct number of gradients? And the way to fix this (so that
register_backward_hook works as intended) is to explicitly write out the
backward pass in a custom
torch.autograd.Function for all layers and matmuls in my code, so that the correct number of gradients is returned?
I think you want to move to
register_full_backward_hook on the nightlies.
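For reference, with the newer API the hook arguments line up with the module’s actual inputs and outputs. A minimal sketch, assuming a build that has `register_full_backward_hook` (1.8+ / nightlies at the time of this thread; the `grads` dict is my own scratch storage):

```python
import torch
import torch.nn as nn

grads = {}

def hook(module, grad_input, grad_output):
    # grad_input:  gradients w.r.t. the module's inputs
    # grad_output: gradients w.r.t. the module's outputs
    grads["in"] = [g.shape for g in grad_input if g is not None]
    grads["out"] = [g.shape for g in grad_output]

layer = nn.Linear(784, 128)
layer.register_full_backward_hook(hook)
layer(torch.randn(64, 784, requires_grad=True)).sum().backward()
# grads["in"] holds one [64, 784] shape (the input gradient) and
# grads["out"] one [64, 128] shape, with no mystery extra entries.
```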
Thank you for the response! So I assume the only way to do KFAC via registering hooks is to use the nightly build? (I’m currently on 1.7, for reference.)
You could always build it yourself; that’s all I know about.
That said, I’ve heard rumors that we can expect the 1.8 RC in February, so it might not be long until a release that has what you need.
Great! I’ll look forward to it! Thanks!