So I’m trying to implement a clean version of KFAC in PyTorch that requires only a single call to a function which augments an nn.Module. I think the hooks mechanism can actually be quite an elegant solution for this; however, I need some clarification to make sure I understand things correctly.
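Concretely, the interface I am after is a single call like the one below (kfac_augment is just a hypothetical name, and the hooks are placeholders):

```python
import torch.nn as nn

def kfac_augment(model: nn.Module) -> None:
    # hypothetical entry point: walk the module tree once and register
    # every hook KFAC needs, so the user only ever calls this function
    for m in model.modules():
        if isinstance(m, nn.Linear):
            m.register_forward_hook(lambda mod, inp, out: None)    # placeholder
            m.register_backward_hook(lambda mod, gin, gout: None)  # placeholder

model = nn.Sequential(nn.Linear(128, 784), nn.ReLU())
kfac_augment(model)
```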
For example, let’s consider nn.Linear. During the backward pass, the backward hook’s grad_input is a tuple of 3 variables. It is unclear to me why there are 3, and I did not find any docs on what they are (one seems to be the grad with respect to the weights, one the grad with respect to the inputs, and the third most likely the grad with respect to the broadcasted bias):
weight : torch.Size([784, 128])
bias : torch.Size([784])
d1: torch.Size([64, 784]) # most likely grad of broadcasted bias
d2: torch.Size([64, 128]) # most likely grad of inputs
d3: torch.Size([128, 784]) # most likely grad of weight.T
grad_output: torch.Size([64, 784])
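For reference, these shapes come from a snippet along these lines (a Linear(128, 784) layer, batch size 64, and the plain register_backward_hook):

```python
import torch
import torch.nn as nn

layer = nn.Linear(128, 784)
for name, p in layer.named_parameters():
    print(name, ":", p.shape)

def hook(module, grad_input, grad_output):
    # grad_input is a 3-tuple for nn.Linear; print each entry's shape
    for i, g in enumerate(grad_input):
        print(f"d{i + 1}: {g.shape}")
    print("grad_output:", grad_output[0].shape)

layer.register_backward_hook(hook)

x = torch.randn(64, 128, requires_grad=True)
layer(x).sum().backward()
```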
Why is the input to the module not part of this, and is there a way to access it? Generally, modifying gradients would require knowing the input.
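The only workaround I see so far is to cache the input from a forward hook and read it back in the backward hook, roughly like this (the attribute name _kfac_input is just mine):

```python
import torch
import torch.nn as nn

layer = nn.Linear(128, 784)

def save_input(module, input, output):
    # forward hooks do receive the input, so stash it on the module
    module._kfac_input = input[0].detach()

def use_input(module, grad_input, grad_output):
    a = module._kfac_input  # activations from the forward pass
    g = grad_output[0]      # gradient w.r.t. the module's output
    # KFAC would build its two Kronecker factors from a and g here

layer.register_forward_hook(save_input)
layer.register_backward_hook(use_input)

x = torch.randn(64, 128, requires_grad=True)
layer(x).sum().backward()
```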
Additionally, is there a way to make the modules produce twice as many gradients as the number of their inputs, but somehow intercept them and do separate things with half of them, while the other half remain normal gradients?
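To illustrate what I mean, here is a rough, untested sketch using a custom torch.autograd.Function (DoubleGrad is a hypothetical name): one input, two outputs, hence two incoming gradients in backward, one kept as the normal gradient and one free to be processed separately.

```python
import torch

class DoubleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # return two copies of x, so backward receives two gradients
        return x, x.clone()

    @staticmethod
    def backward(ctx, grad_normal, grad_extra):
        # grad_extra could be stashed or processed separately here;
        # only grad_normal is propagated as the usual gradient of x
        return grad_normal

x = torch.randn(3, requires_grad=True)
y, y_extra = DoubleGrad.apply(x)
(y.sum() + y_extra.sum()).backward()  # backward sees two gradients
```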