In the following code, no loss is computed from dup_x, yet the hook registered on dup_x_acc is still called. Why? I assumed a node in next_functions is only called if the grad_fn pointing to it is called. In this case, however, dup_x.grad_fn is never called, if I'm correct.
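```python
import logging

import torch

logging.basicConfig(level=logging.INFO)  # without this, logging.info output is hidden

x = torch.randn(2, requires_grad=True)
l = x.sum()                                     # the loss depends on x directly
dup_x = x.expand_as(x)                          # dup_x is never used by l
dup_x_acc = dup_x.grad_fn.next_functions[0][0]  # node after dup_x.grad_fn

def hook(*args, **kwargs):
    logging.info('called')

dup_x_acc.register_hook(hook)
l.backward()  # the hook on dup_x_acc still fires here
```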
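One way to narrow this down might be to check which node dup_x_acc actually is, and whether it is the same object that l.grad_fn reaches. The sketch below assumes a recent PyTorch where graph nodes expose `next_functions`, accumulator nodes expose `.variable`, and node hooks are called as `hook(grad_inputs, grad_outputs)`. The names `l_acc` and `calls` are hypothetical, introduced only for this check.

```python
import torch

x = torch.randn(2, requires_grad=True)
l = x.sum()
dup_x = x.expand_as(x)

dup_x_acc = dup_x.grad_fn.next_functions[0][0]  # node reached from dup_x.grad_fn
l_acc = l.grad_fn.next_functions[0][0]          # hypothetical: node reached from l.grad_fn

print(type(dup_x_acc).__name__)  # what kind of node is it?
print(dup_x_acc.variable is x)   # which leaf tensor does it accumulate into?
print(dup_x_acc is l_acc)        # is it the very same node object in both graphs?

# Count hook calls instead of logging, so nothing depends on the logging level.
calls = []
dup_x_acc.register_hook(lambda grad_inputs, grad_outputs: calls.append(1))
l.backward()
print(len(calls))
```

If the last two identity checks print True, the hook on dup_x_acc would fire simply because l.backward() reaches that same node through l.grad_fn, not because dup_x.grad_fn ran.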