Understanding `register_forward_pre_hook` and `register_backward_hook`

Problem

I wrote the following code to check my understanding of these two functions:

import torch
import torch.nn as nn

class LinearTransformation(nn.Module):
    def __init__(self, constant=False):
        super(LinearTransformation, self).__init__()
        self.linear = nn.Linear(2, 2)
        if constant:
            nn.init.constant_(self.linear.weight, 10)  # every weight = 10
        else:
            nn.init.eye_(self.linear.weight)  # identity weight matrix
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x):
        # squared L2 norm of the linear layer's output
        return torch.norm(self.linear(x)) ** 2

Invoking this simple class in the following snippets prints:

def hook(module, input):
    print(input)

model = LinearTransformation(constant=False)

x = torch.tensor([1., 2.])

model.linear.register_forward_pre_hook(hook)
model(x)
# (tensor([1., 2.]),)
# tensor(5., grad_fn=<PowBackward0>)

def hook(module, grad_input, grad_output):
    print(grad_input)
    print(grad_output)

model = LinearTransformation(constant=True)

x = torch.tensor([1., 2.])

model.linear.register_backward_hook(hook)
model(x).backward()
# (tensor([60., 60.]), tensor([60., 60.]))
# (tensor([60., 60.]),)

My questions are:

  • For register_forward_pre_hook (first snippet), why is 5, the final output, also printed when I only registered the hook on nn.Linear?
  • For register_backward_hook (second snippet), I am not sure what these tensor([60., 60.]) values correspond to. My guess is that grad_output is the gradient with respect to the output of nn.Linear, but what are the other two tensor([60., 60.])?

Hi,

In the first case, the hook should only receive the input, not the output.
Are you running this in an interactive interpreter? That would explain why the value returned by model(x) is echoed after the hook's output.
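A minimal sketch of the same check run as a plain script, assuming a recent PyTorch release; the list `seen` and the function `record_input` are illustrative names. Instead of printing, the hook stores what it receives, which shows that a forward pre-hook only ever gets the inputs of nn.Linear, while the 5 is just the return value of model(x).

```python
import torch
import torch.nn as nn

class LinearTransformation(nn.Module):
    def __init__(self, constant=False):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        if constant:
            nn.init.constant_(self.linear.weight, 10)
        else:
            nn.init.eye_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x):
        return torch.norm(self.linear(x)) ** 2

seen = []  # illustrative: collects whatever the pre-hook receives

def record_input(module, inputs):
    # A forward pre-hook is called as hook(module, inputs) *before* forward runs,
    # so it only sees the tuple of positional inputs, never the output.
    seen.append(inputs)

model = LinearTransformation(constant=False)
handle = model.linear.register_forward_pre_hook(record_input)
out = model(torch.tensor([1., 2.]))
handle.remove()  # detach the hook once it is no longer needed

print(seen)        # only the inputs to nn.Linear
print(out.item())  # the value an interactive interpreter would echo
```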

As for the backward hook, as the documentation warns, it does not currently behave as expected for modules that perform more than one operation, so you should not use it.
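For context, the old register_backward_hook attaches to the last autograd operation inside the module. In the PyTorch versions of that time, nn.Linear applied to a 1-D input most likely computed a matmul followed by an in-place bias addition, so grad_input would hold the gradients with respect to the two operands of that addition (the matmul result and the bias). Both equal dL/dy = 2y = [60., 60.] for y = Wx = [30., 30.]. Newer releases (1.8 and later) deprecate register_backward_hook in favor of Module.register_full_backward_hook, which reports gradients with respect to the module's actual inputs and outputs. A sketch, assuming such a release; `captured` and `grad_hook` are illustrative names, and `x` is made to require grad so that its gradient is reported:

```python
import torch
import torch.nn as nn

class LinearTransformation(nn.Module):
    def __init__(self, constant=False):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        if constant:
            nn.init.constant_(self.linear.weight, 10)
        else:
            nn.init.eye_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x):
        return torch.norm(self.linear(x)) ** 2

captured = {}  # illustrative store for the hook's arguments

def grad_hook(module, grad_input, grad_output):
    # grad_output: gradients w.r.t. nn.Linear's output y = W x + b
    # grad_input:  gradients w.r.t. nn.Linear's input x
    captured["grad_input"] = grad_input
    captured["grad_output"] = grad_output

model = LinearTransformation(constant=True)  # W is all 10s, b = 0
model.linear.register_full_backward_hook(grad_hook)

x = torch.tensor([1., 2.], requires_grad=True)
model(x).backward()

# y = W x = [30, 30], so dL/dy = 2 * y = [60, 60]
print(captured["grad_output"])
# dL/dx = W^T dL/dy: each entry is 10*60 + 10*60 = 1200
print(captured["grad_input"])
```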

Hi,

Can I use a hook to add a parameter-masking function to Conv2d? Specifically, I'd like to add a binary mask buffer to each Conv2d module; during each training step, I need to update the mask buffer and then use it to mask the weight.

Thanks!

I guess you could write the masking part in the model's forward() function, unless you would like finer control over the type of mask you use.
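A minimal sketch of doing the masking inside forward(), assuming a plain Conv2d with the default padding_mode="zeros"; the class name MaskedConv2d and its update_mask method are hypothetical. The mask is a buffer, so it follows .to()/.cuda() and is saved in the state_dict, but it is not trained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedConv2d(nn.Conv2d):
    """Hypothetical Conv2d whose weight is multiplied by a binary mask buffer."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # All-ones mask to start with (nothing masked); buffers are not parameters.
        self.register_buffer("mask", torch.ones_like(self.weight))

    @torch.no_grad()
    def update_mask(self, new_mask):
        # Overwrite the mask in place, e.g. once per training step.
        self.mask.copy_(new_mask)

    def forward(self, x):
        # Assumes padding_mode="zeros", so F.conv2d's own padding is sufficient.
        return F.conv2d(x, self.weight * self.mask, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)

conv = MaskedConv2d(1, 1, kernel_size=3, padding=1)
mask = torch.zeros_like(conv.weight)
mask[..., 1, 1] = 1.0  # keep only the centre tap of the 3x3 kernel
conv.update_mask(mask)

x = torch.randn(1, 1, 5, 5)
out = conv(x)
```

Because the product self.weight * self.mask is part of the graph, masked entries receive zero gradient.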

If you are using the nightly build, we just landed a pruning utility in nn.utils (torch.nn.utils.prune). The upcoming tutorial for it can be found here.
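A sketch of attaching a custom binary mask with that utility, assuming a PyTorch version that ships torch.nn.utils.prune; the mask values are arbitrary. custom_from_mask reparametrizes weight as weight_orig * weight_mask and registers a forward pre-hook that recomputes weight before every forward call, so updating the weight_mask buffer in place between steps should take effect on the next forward.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(1, 1, kernel_size=3)

# Keep only the centre tap of the 3x3 kernel (arbitrary example mask).
mask = torch.zeros_like(conv.weight)
mask[..., 1, 1] = 1.0
prune.custom_from_mask(conv, name="weight", mask=mask)

# The trainable tensor is now the parameter "weight_orig", the mask is the
# buffer "weight_mask", and "weight" is recomputed by a forward pre-hook.
print([name for name, _ in conv.named_parameters()])
print([name for name, _ in conv.named_buffers()])

# Later, e.g. once per training step: update the mask in place.
with torch.no_grad():
    conv.weight_mask.fill_(1.0)  # unmask everything again
conv(torch.randn(1, 1, 5, 5))    # the pre-hook recomputes conv.weight here
```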

Otherwise, you will have to emulate what is done in the pruning module yourself.
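A rough sketch of such an emulation, in the spirit of the original question (a mask buffer plus a forward pre-hook). The helper names add_weight_mask and apply_weight_mask are hypothetical, and this is a simplification of what the pruning module does, not its exact implementation.

```python
import torch
import torch.nn as nn

def apply_weight_mask(module, inputs):
    # Forward pre-hook: rebuild `weight` from the trainable original and the mask
    # right before every forward call, so gradients reach weight_orig only
    # through the unmasked entries.
    module.weight = module.weight_orig * module.weight_mask

def add_weight_mask(module, mask):
    # Move the trainable tensor to `weight_orig` and store the mask as a buffer.
    orig = module.weight
    del module._parameters["weight"]
    module.register_parameter("weight_orig", orig)
    module.register_buffer("weight_mask", mask.clone())
    apply_weight_mask(module, None)  # make `weight` valid immediately
    return module.register_forward_pre_hook(apply_weight_mask)

conv = nn.Conv2d(1, 1, kernel_size=3)
mask = torch.zeros_like(conv.weight)
mask[..., 1, 1] = 1.0  # keep only the centre tap
handle = add_weight_mask(conv, mask)

out = conv(torch.randn(1, 1, 5, 5))
out.sum().backward()
# conv.weight_orig.grad is zero everywhere the mask is zero
```

To update the mask during training, overwrite conv.weight_mask in place under torch.no_grad(); the hook picks it up on the next forward.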

Thank you so much! I'll try it in the next few days. Does this utility support multi-GPU training and mixed-precision training?

Yes, for the use cases that have been tested, pruning works well with DataParallel, DistributedDataParallel, as well as apex and torch.quantization.
Please flag anything that doesn’t work as you try out the functionality.

So then what is the best way to check the gradients of each layer? I used to apply a forward hook to each layer with register_forward_hook, and was thinking of doing the same for gradients by calling register_backward_hook on each layer.

Hi,

The best way is to call register_hook() during the forward pass on the Tensors whose gradients you want.
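A minimal sketch of that approach for the model from the question, assuming the hook is registered inside forward(); the grads dict, the save_grad function and the key "linear_out" are illustrative names. Tensor.register_hook calls the function with the gradient of the loss with respect to that tensor during backward, which should match the grad_output from the question ([60., 60.]).

```python
import torch
import torch.nn as nn

class LinearTransformation(nn.Module):
    def __init__(self, constant=False):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        if constant:
            nn.init.constant_(self.linear.weight, 10)
        else:
            nn.init.eye_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0)
        self.grads = {}  # illustrative store for captured gradients

    def forward(self, x):
        y = self.linear(x)
        if y.requires_grad:  # no graph (and no hook) under torch.no_grad()
            def save_grad(grad):
                # Called during backward with dL/dy; returning None keeps grad unchanged.
                self.grads["linear_out"] = grad
            y.register_hook(save_grad)
        return torch.norm(y) ** 2

model = LinearTransformation(constant=True)
model(torch.tensor([1., 2.])).backward()
print(model.grads["linear_out"])
```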