Why is next_functions[0][0] called when there is no loss on its branch?

In the following code, there is no loss on dup_x, yet the hook on dup_x_acc is called. I assumed a next function is only called if its grad_fn is called, and if I'm correct, dup_x.grad_fn is never called here.

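      import logging
      import torch

      logging.basicConfig(level=logging.INFO)  # otherwise logging.info prints nothing

      x = torch.randn(2, requires_grad=True)
      l = x.sum()  # the only loss; it does not involve dup_x

      dup_x = x.expand_as(x)
      dup_x_acc = dup_x.grad_fn.next_functions[0][0]  # AccumulateGrad node of x

      def hook(*args, **kwargs):
          logging.info('called')

      dup_x_acc.register_hook(hook)
      l.backward()  # 'called' is logged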

You can get a better intuition of what’s happening by plotting the autograd graphs of dup_x and l with torchviz.

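In the torchviz graphs, the node dup_x.grad_fn.next_functions[0][0] is the AccumulateGrad node of the leaf x, and it is exactly the same AccumulateGrad node that appears in the graph of l: a leaf tensor reuses one AccumulateGrad node for as long as any graph still holds it. So l.backward() never runs dup_x.grad_fn (the backward of expand_as); it reaches that shared AccumulateGrad node directly from l.grad_fn (the backward of sum), and that is why your hook fires.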
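If torchviz is not at hand, a plain-text walk over grad_fn.next_functions shows the same structure. This is a minimal sketch; the helpers walk and show are hypothetical names, not part of PyTorch. It lists each node with its depth and then checks that the leaf node at the bottom of both graphs is a single shared object.

```python
import torch


def walk(fn, depth=0, out=None):
    """Depth-first walk of an autograd graph from fn, collecting (depth, node).

    A rough text stand-in for a torchviz plot. Nodes shared by several
    paths are listed once per path, so this is a tree view of the graph.
    """
    if out is None:
        out = []
    if fn is None:  # inputs that do not require grad appear as None
        return out
    out.append((depth, fn))
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1, out)
    return out


def show(fn):
    """Print one indented line per node, using the node's class name."""
    for depth, node in walk(fn):
        print("  " * depth + type(node).__name__)


x = torch.randn(2, requires_grad=True)
l = x.sum()
dup_x = x.expand_as(x)

show(dup_x.grad_fn)  # backward node of expand_as, then AccumulateGrad
show(l.grad_fn)      # backward node of sum, then AccumulateGrad

# The leaf node at the bottom of both walks is one and the same object
dup_leaf = walk(dup_x.grad_fn)[-1][1]
l_leaf = walk(l.grad_fn)[-1][1]
print(dup_leaf is l_leaf)
```

Both walks end in an AccumulateGrad node, and the final identity check is what makes the hook behaviour unsurprising: there is only one such node for x.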

You can verify this with:

      assert dup_x.grad_fn.next_functions[0][0] is l.grad_fn.next_functions[0][0]
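A small contrast case can confirm the reachability explanation. In this sketch, y is a hypothetical second leaf (not in the original question) that gets the same expand_as treatment but is not part of l. Only the AccumulateGrad node that l.backward() can reach from l.grad_fn runs its hook.

```python
import torch

x = torch.randn(2, requires_grad=True)
y = torch.randn(2, requires_grad=True)  # hypothetical leaf, not used by l

l = x.sum()  # the only loss; y is not part of it

dup_x = x.expand_as(x)
dup_y = y.expand_as(y)

acc_x = dup_x.grad_fn.next_functions[0][0]
acc_y = dup_y.grad_fn.next_functions[0][0]

# x has one AccumulateGrad node, shared by the graphs of dup_x and l
assert acc_x is l.grad_fn.next_functions[0][0]

calls = []
# Node hooks are called as hook(grad_inputs, grad_outputs); we only record the call
acc_x.register_hook(lambda grad_inputs, grad_outputs: calls.append("x"))
acc_y.register_hook(lambda grad_inputs, grad_outputs: calls.append("y"))

l.backward()
print(calls)  # only the node reachable from l.grad_fn fired
```

The hook on acc_y never runs and y.grad stays None, because nothing in l's graph leads to y, whereas acc_x is reached through l's own graph even though dup_x.grad_fn is never executed.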
