It should work. This uses the output of B(inputs) to connect the two graphs together. The AllReduce communications from A and B shouldn't interleave, I think. If it hangs somehow, you could try setting the process_group argument of the two DDP instances to different ProcessGroup objects created with the new_group API. This fully decouples the communication of A and B.
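A minimal sketch of that decoupling (assuming A_net and B_net are your local nn.Modules, local_rank and inputs are placeholders, and init_process_group has already run on every rank):

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

world = list(range(dist.get_world_size()))

# new_group must be called by every rank with the same arguments; the two
# groups span the same ranks but get separate communicators underneath,
# so A's and B's AllReduce traffic cannot get mixed up.
pg_a = dist.new_group(ranks=world)
pg_b = dist.new_group(ranks=world)

A = DDP(A_net.cuda(local_rank), device_ids=[local_rank], process_group=pg_a)
B = DDP(B_net.cuda(local_rank), device_ids=[local_rank], process_group=pg_b)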
It seems that gradients will be synced when loss.backward() is called.
Yes, see this page for more detail: Distributed Data Parallel — PyTorch 2.1 documentation
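For example (continuing the sketch above, with a made-up target and SGD purely for illustration), the single backward() call is what triggers the gradient AllReduce for both DDP instances:

import torch
import torch.nn.functional as F

opt = torch.optim.SGD(list(A.parameters()) + list(B.parameters()), lr=0.1)

opt.zero_grad()
out = A(B(inputs))              # one forward through each DDP instance
loss = F.mse_loss(out, target)
loss.backward()                 # autograd hooks of both A and B fire here,
                                # launching their bucketed AllReduce calls
opt.step()                      # gradients are already averaged across ranks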
Q2: If loss = A(B(inputs1), B(inputs2)), will DDP work? The forward function of B is called twice. By the way, I don't know what reducer.prepare_for_backward does…
This won’t work. DDP requires forward and backward to run alternately. The above code would run forward on B twice before one backward, which would mess up DDP's internal state. However, the following would work. Suppose the local module wrapped by B is C:
class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.c = C()

    def forward(self, inputs):
        # run the shared submodule on both inputs within a single DDP forward
        return self.c(inputs[0]), self.c(inputs[1])

B = DistributedDataParallel(Wrapper(), ...)
loss = A(B([inputs1, inputs2]))
This is basically using a thin wrapper over C to process both inputs in one forward call.
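As far as I know, this works because both calls to self.c now sit inside a single DDP forward: autograd accumulates each shared parameter's gradient from both uses during the one backward pass, so DDP still observes exactly one forward/backward pair per iteration, which is what its reducer expects.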