Hi there, I'm new to PyTorch. I was just reading the DistributedDataParallel code and have some related questions:
```python
def _register_grad_hooks(self):
    self._grad_accs = []  # need to keep them in scope
    for device_idx, module in enumerate(self._module_copies):
        for p in module.parameters():
            if p.requires_grad:
                p_tmp = p.expand_as(p)
                grad_acc = p_tmp.grad_fn.next_functions[0][0]
                grad_acc.register_hook(self._make_param_hook(p, device_idx))
                self._grad_accs.append(grad_acc)
```
Could someone explain how this works? Specifically, why does it create a new tensor with expand_as and then register the hook on that new tensor's grad_fn.next_functions[0][0]? What kind of function is that?
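For context, here is what I see when I poke at this with a toy tensor (a minimal probe; the variable names are mine, and the exact grad_fn class name may vary across PyTorch versions):

```python
import torch

p = torch.randn(3, requires_grad=True)          # stands in for a leaf parameter
p_tmp = p.expand_as(p)                          # non-leaf view of p, so it has a grad_fn
print(type(p_tmp.grad_fn).__name__)             # e.g. ExpandBackward0
grad_acc = p_tmp.grad_fn.next_functions[0][0]   # the node the DDP code grabs
print(type(grad_acc).__name__)                  # AccumulateGrad
```

So the object being hooked appears to be the AccumulateGrad node for the leaf parameter, but I don't understand why it is reached through an expanded copy.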
Also, what's the difference between this and directly registering a hook on the original tensor with Tensor.register_hook?
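By "directly" I mean something like the following (a hypothetical sketch of the alternative I have in mind, not the actual DDP code; make_hook is a name I made up):

```python
import torch
import torch.nn as nn

module = nn.Linear(4, 2)

def make_hook(p):
    # hypothetical stand-in for whatever per-parameter work DDP does
    def hook(grad):
        print("got grad for param of shape", tuple(p.shape))
    return hook

for p in module.parameters():
    if p.requires_grad:
        p.register_hook(make_hook(p))  # fires with dL/dp during backward

module(torch.randn(1, 4)).sum().backward()
```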
Thanks.