How do hooks work?

I am reading the docs, and I am confused about how hooks work. Here are some questions in my mind:

  1. Does PyTorch maintain the Variable's consumers in the Variable object?
  2. At the backward stage, is the gradient of an intermediate Variable accumulated from its different consumers and saved in variable.grad, and after backward, are all the variable.grad of intermediate Variables reset to None?
  3. Both types of hooks are registered on the Function object; at the backward stage, is the calling order
    hook registered by the Variable -> function.backward(...) -> hook registered by the Module -> update the Variable's grad?

Can anyone help me sort this out? Thanks!

  1. Variables have backward references, so conceptually a list of “producers”, not “consumers”. The reference chain is Variable -> Function -> … -> Function -> (root) Variable. The Function objects also have references to the tensors needed to compute derivatives: for the matrix multiply of Variables C = A @ B, C.creator (a Function) has references to A.data and B.data (tensors). (See the sketch at the end of this reply.)

  2. The gradients of intermediate Variables are accumulated in a C++ grad_buffer object. It’s not exposed in Python except indirectly through gradient hooks (also shown in the sketch below). The .grad of intermediate Variables is always None.

  3. I think the ordering of “hooks registered by modules” and “hooks registered by variables” is swapped, but I’m not 100% sure. (Test it.)

Some of the automatic differentiation behavior will change a bit in an upcoming PR.
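
To make points 1 and 2 concrete, here is a rough sketch using the old Variable/.creator API this thread is based on (later releases rename the attribute to .grad_fn); treat it as illustrative rather than exact for your version:

import torch
from torch.autograd import Variable

A = Variable(torch.randn(2, 3), requires_grad=True)
B = Variable(torch.randn(3, 4), requires_grad=True)
C = A.mm(B)  # C = A @ B

# 1. C only points backwards, at the Function that produced it
#    (.creator in old releases, .grad_fn in newer ones).
creator = C.creator if hasattr(C, 'creator') else C.grad_fn
print(creator)

# 2. The gradient of the intermediate Variable C is only visible
#    through a hook; C.grad itself stays None.
seen = {}
C.register_hook(lambda g: seen.update(C=g))  # returning None leaves the grad unchanged

C.sum().backward()
print(A.grad)      # leaf Variable: gradient is accumulated here
print(C.grad)      # intermediate Variable: None
print(seen['C'])   # gradient that was delivered to the hook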


Thanks a lot @colesbury, I wrote the test code below.

import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.nn import Parameter
from torch.autograd import Function
import math

class _Linear(Function):

    # bias is an optional argument
    def forward(self, input, weight, bias=None):
        self.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    def backward(self, grad_output):
        input, weight, bias = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        print("backwarding......")
        if self.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if self.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and self.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias

def module_hook(module, grad_input, grad_out):
    # grad_input / grad_out are tuples of gradients w.r.t. the module's inputs / outputs
    print('module hook')
    print('grad_out', grad_out)

def variable_hook(grad):
    # returning a value replaces the gradient that flows further back
    print('variable hook')
    print('grad', grad)
    return grad*.1

class Linear(nn.Module):

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        if self.bias is None:
            return _Linear()(input, self.weight)
        else:
            return _Linear()(input, self.weight, self.bias)

linear = Linear(3, 1)
linear.register_backward_hook(module_hook)  # hook registered on the Module
value = Variable(torch.FloatTensor([[1, 2, 3]]), requires_grad=True)

res = linear(value)
res.register_hook(variable_hook)  # hook registered on the output Variable

res.backward()

And the output of the code above is:

variable hook
grad Variable containing:
 1
[torch.FloatTensor of size 1x1]

backwarding......
module hook
grad_out (Variable containing:
 0.1000
[torch.FloatTensor of size 1x1]
,)

It seems that “hooks registered by the Variable” -> “backward()” -> “hooks registered by the Module” is right. (The module hook sees 0.1 because variable_hook scaled the incoming gradient of 1 by 0.1 before _Linear.backward ran.)
Looking forward to the updated version. :grin:
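
If you also want to see what reaches the leaf Variable after the variable hook has scaled the gradient, you could register one more hook on the input before calling backward(); a small, untested extension of the same script:

# extension: also hook the input (leaf) Variable
value.register_hook(lambda g: print('leaf variable hook', g))

res = linear(value)
res.register_hook(variable_hook)
res.backward()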


What tutorial are you referencing?