Hi,
I have quite a specific problem: how can I make one node wait for the gradients arriving from two separate computational graphs (that share the “same” node)? The following code exemplifies my problem:
import torch

# Create a tensor and set requires_grad=True to track computation
x = torch.tensor([1.0], requires_grad=True)
# Start a computational graph
y = x * 2
y.retain_grad() # to be able to check gradient later
# detach creates another computational graph
residual = y.detach()
residual.requires_grad_(True)
z = y ** 2
k = z + residual
# Compute gradients
k.backward()
print("y grad")
print(y.grad) # should be 5, but prints 4 (when using detach())
print("x grad")
print(x.grad) # should be 10, but prints 8 (when using detach())
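For reference, this is how I get the expected numbers: with x = 1 we have y = 2, so

dk/dy = dz/dy + dk/dresidual = 2*y + 1 = 5
dk/dx = dk/dy * dy/dx = 5 * 2 = 10

With the detach(), the contribution of 1 from residual never reaches y, which is why I see 4 and 8 instead.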
So, specifically, I would want y to accumulate the output gradients coming from both residual and z, and only after that continue the backpropagation (to x).
I understand that for this specific example it would work if I simply did not use detach() in residual = y.detach(). But this accumulation logic is needed in another, more complicated scenario, where the z path actually calls a separate backward(). I would want the backpropagation to be blocked at y until it has received the gradients from both computational graphs connected to it (see the sketch below for what I mean). Maybe DDP (distributed data parallel) could help?
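To make the intent concrete, here is a minimal sketch of the behaviour I am after for the toy example above, done manually with torch.autograd.grad so that nothing flows past y until both gradients are in hand. The names mirror my snippet, and I am not sure this generalizes to my real setup where the z path runs its own backward():

import torch

x = torch.tensor([1.0], requires_grad=True)
y = x * 2

# detached copy, as above
residual = y.detach()
residual.requires_grad_(True)

z = y ** 2
k = z + residual

# Stop the backward pass at y / residual instead of running k.backward(),
# so x does not receive a partial gradient yet
grad_y_from_z, = torch.autograd.grad(k, y, retain_graph=True)
grad_y_from_residual, = torch.autograd.grad(k, residual, retain_graph=True)

# Only now continue from y down to x with the accumulated gradient
y.backward(grad_y_from_z + grad_y_from_residual)

print(x.grad)  # tensor([10.])

This gives the 10 I expect, but it relies on me being able to see both graphs in one place, which is not the case in my real scenario.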