Multi-model and multi-forward in distributed data parallel

It should work. This uses the output of B(inputs) to connect the two graphs together. The AllReduce communication from A and B should not interleave, I think. If it does hang somehow, you could try setting the process_group argument of the two DDP instances to different ProcessGroup objects created with the new_group API. This will fully decouple the communication of A and B.
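For example, here is a minimal sketch of that decoupled setup, assuming the default process group has already been initialized and that model_a and model_b (illustrative names) are the local modules already placed on this rank's device:

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Every rank must create the groups in the same order with the same ranks.
all_ranks = list(range(dist.get_world_size()))
pg_a = dist.new_group(ranks=all_ranks)  # dedicated communicator for A's gradient allreduce
pg_b = dist.new_group(ranks=all_ranks)  # dedicated communicator for B's gradient allreduce

A = DistributedDataParallel(model_a, process_group=pg_a)
B = DistributedDataParallel(model_b, process_group=pg_b)

loss = A(B(inputs))  # B's output feeds A, so the two graphs are connected
loss.backward()      # gradient allreduce for both A and B runs during this call

Note that new_group is a collective call: all processes must invoke it, in the same order, even ones that are not members of the resulting group.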

It seems that gradients will be synced when loss.backward() is called.

Yes, see this page for more detail: Distributed Data Parallel — PyTorch 2.1 documentation

Q2: If loss = A(B(inputs1), B(inputs2)), will DDP work? The forward function of B is called twice. Btw, I don’t know what reducer.prepare_for_backward does…

This won’t work. DDP requires forward and backward to run alternately. The above code would run forward on B twice before one backward, which would mess up DDP’s internal state. However, the following would work. Suppose the local module wrapped by B is C:


import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.c = C()

    def forward(self, inputs):
        # Run the shared submodule on both inputs inside a single forward call.
        return self.c(inputs[0]), self.c(inputs[1])


B = DistributedDataParallel(Wrapper(), ...)

loss = A(*B([inputs1, inputs2]))  # unpack both outputs into A, as in A(B(inputs1), B(inputs2))

This basically uses a thin wrapper over C to process both inputs in one forward call.
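For completeness, a sketch of one training iteration with this wrapper, assuming A produces the scalar loss (as in the question) and an optimizer that covers the parameters of both A and B:

for inputs1, inputs2 in loader:           # loader is assumed
    optimizer.zero_grad()
    out1, out2 = B([inputs1, inputs2])    # exactly one forward pass through B
    loss = A(out1, out2)                  # A consumes both outputs and returns the loss
    loss.backward()                       # single backward; DDP syncs gradients for A and B here
    optimizer.step()

Each iteration runs exactly one forward and one backward per DDP instance, which keeps the reducer state consistent.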
