How to combine data parallelism with model parallelism for multiple nodes?

For example, I have 4 nodes and each node has two GPUs. I want to divide one model into four parts, so that each node runs its part of the model and uses data parallelism on its two GPUs.

I use a hook to get the gradients and "dist.send" to send them to the other node; this works for model parallelism.

on node 1:

dist.init_process_group(backend="gloo", init_method="tcp://172.22.4.11:28456",
                        rank=0, world_size=2)

# outputs is the result (activations) of node 1
dist.send(tensor=outputs.to("cpu"), dst=1, tag=0)

# rec holds the gradients sent back from node 2
dist.recv(tensor=rec, src=1)
outputs.backward(rec.cuda())

on node 2:

dist.init_process_group(backend="gloo", init_method="tcp://172.22.4.11:28456",
                        rank=1, world_size=2)

# rec receives the result of node 1
dist.recv(tensor=rec, src=0, tag=0)
outputs2 = net2(rec)

# feat[0] is the gradients computed on node 2 (captured by the hook)
dist.send(tensor=feat[0].to("cpu"), dst=0)

But when I try to combine data parallelism with model parallelism, it fails. I use "torch.nn.parallel.DistributedDataParallel" for data parallelism, but node 2 can't receive the gradients from node 1.

Question:
So how to combine data parallelism with model parallelism for multiple nodes?

It might be easier to run model parallelism across multiple GPUs within the same machine and distributed data parallelism across machines. Check out this section for more details.
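As a minimal sketch of that layout (the model class, `node_rank`, and the init address are illustrative placeholders, not from the original post): each node runs one process that holds a model split across its two local GPUs, and DDP synchronizes the four node-level replicas.

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

class TwoGPUShardedModel(nn.Module):
    """Toy model split across the two local GPUs of one node."""
    def __init__(self, dev0, dev1):
        super().__init__()
        self.dev0, self.dev1 = dev0, dev1
        self.net1 = nn.Linear(10, 10).to(dev0)  # first half on GPU 0
        self.net2 = nn.Linear(10, 5).to(dev1)   # second half on GPU 1

    def forward(self, x):
        x = torch.relu(self.net1(x.to(self.dev0)))
        return self.net2(x.to(self.dev1))

# One process per node, so world_size=4; node_rank comes from your launcher.
dist.init_process_group(backend="gloo",
                        init_method="tcp://172.22.4.11:28456",
                        rank=node_rank, world_size=4)

model = TwoGPUShardedModel("cuda:0", "cuda:1")
# For a multi-device module, do not pass device_ids/output_device to DDP.
ddp_model = DDP(model)
```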

For your use case above, you will need to create multiple process groups. Given the configuration of 4 nodes with 2 GPUs per node, I assume you will have 8 processes, one process per GPU. Then you can create:

  1. Two process groups of world size 4, which will be responsible for send/recv outputs/gradients across machines.
  2. One process group of world size 2 on EACH machine, which will do the distributed data parallel on the machine.

The reason you need the above setup is that DistributedDataParallel expects all processes in the same group to train the same model in a synchronized fashion. It won't work if only 2 of the processes in a group of size 8 participate.

See the new_group API.
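As a rough illustration of that layout (the rank assignment and names like `global_rank` and `net` are assumptions, not from the original post), the groups could be created like this:

```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# 8 processes in total; assume global rank = node_id * 2 + local GPU index.
dist.init_process_group(backend="gloo",
                        init_method="tcp://172.22.4.11:28456",
                        rank=global_rank, world_size=8)

# Every process must call new_group() for every group, in the same order,
# even if it only belongs to some of them.

# Two groups of world size 4 for sending activations/gradients across machines
# (e.g. GPU 0 of every node forms one pipeline, GPU 1 the other).
pipeline_groups = [dist.new_group(ranks=[0, 2, 4, 6]),
                   dist.new_group(ranks=[1, 3, 5, 7])]

# One group of world size 2 per node for DistributedDataParallel on that node.
ddp_groups = [dist.new_group(ranks=[2 * n, 2 * n + 1]) for n in range(4)]

local_rank = global_rank % 2          # GPU index on this node
my_ddp_group = ddp_groups[global_rank // 2]
my_pipeline_group = pipeline_groups[local_rank]

# Wrap this node's shard of the model (here called `net`) with DDP over the
# local two-GPU group only.
ddp_shard = DDP(net.cuda(local_rank), device_ids=[local_rank],
                process_group=my_ddp_group)

# Cross-machine send/recv then passes the pipeline group explicitly; note that
# dst/src are still global ranks, e.g.:
# dist.send(tensor=outputs.cpu(), dst=next_global_rank, group=my_pipeline_group)
```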

BTW, the torch.distributed.rpc API might make sending/receiving outputs/grads easier for you, and it also supports distributed autograd.
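For example, a stripped-down two-worker sketch (the worker names, shapes, and nn.Linear stand-ins are purely illustrative) could look like the following; distributed autograd then propagates gradients back over the RPC boundary, so no manual send/recv of gradients is needed:

```python
import torch
import torch.nn as nn
import torch.distributed.rpc as rpc
import torch.distributed.autograd as dist_autograd

# MASTER_ADDR / MASTER_PORT must be set in the environment (or passed via
# rpc_backend_options). worker1 only needs init_rpc(...) followed by shutdown().
rpc.init_rpc("worker0", rank=0, world_size=2)

net1 = nn.Linear(10, 10)                                    # local shard
net2_rref = rpc.remote("worker1", nn.Linear, args=(10, 5))  # remote shard

with dist_autograd.context() as context_id:
    x = torch.randn(4, 10)
    hidden = net1(x)
    out = net2_rref.rpc_sync().forward(hidden)    # run the remote shard
    loss = out.sum()
    # Backward crosses the RPC boundary automatically; gradients are stored
    # per-context and can be read with dist_autograd.get_gradients(context_id).
    dist_autograd.backward(context_id, [loss])

rpc.shutdown()
```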