How to combine data parallelism with model parallelism

I am trying to combine data parallelism with model parallelism, but I run into RuntimeError: all tensors must be on devices[0] in the process. Below is a simplified example of my code (my torch version is 1.0.1.post2):

import torch
import torch.distributed
import torch.multiprocessing
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        # Remember the input device so the output can be moved back to it.
        first_device = x.device
        x = self.fc1(x.to(self.fc1.weight.device))
        x = F.relu(x)
        x = self.fc2(x.to(self.fc2.weight.device))
        x = F.softmax(x, dim=-1).to(first_device)
        return x

    def model_parallel(self, start):
        # Split the model across two consecutive GPUs.
        self.fc1.cuda(start)
        self.fc2.cuda(start + 1)


def run(rank, device_id, world_size):
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='tcp://localhost:10000',
        world_size=world_size,
        rank=rank,
    )
    model = MyModel()
    model.model_parallel(device_id)  # fc1 on cuda:device_id, fc2 on cuda:device_id + 1
    model = nn.parallel.DistributedDataParallel(
        module=model,
        device_ids=list(range(device_id, device_id + world_size)),  # the GPUs this replica spans
        output_device=device_id,
        broadcast_buffers=False,
    )
    model(torch.randn(1, 784).cuda(device_id))



if __name__ == "__main__":
    mp = torch.multiprocessing.get_context('spawn')

    world_size = 2  # number of data-parallel replicas (one process each)
    model_size = 2  # number of GPUs each replica is split across
    procs = []
    for rank in range(world_size):
        device_id = rank * model_size  # first GPU owned by this replica
        procs.append(mp.Process(target=run, args=(rank, device_id, world_size), daemon=True))
        procs[rank].start()
    for p in procs:
        p.join()

The full traceback is:

Process SpawnProcess-1:
Traceback (most recent call last):
  File "/home/user/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/home/user/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/user/nmt-research/example.py", line 34, in run
    broadcast_buffers=False,
  File "/home/user/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 217, in __init__
    self._ddp_init_helper()
  File "/home/user/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 232, in _ddp_init_helper
    self._module_copies = replicate(self.module, self.device_ids, detach=True)
  File "/home/user/lib/python3.6/site-packages/torch/nn/parallel/replicate.py", line 13, in replicate
    param_copies = Broadcast.apply(devices, *params)
  File "/home/user/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 21, in forward
    outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
  File "/home/user/lib/python3.6/site-packages/torch/cuda/comm.py", line 40, in broadcast_coalesced
    return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: all tensors must be on devices[0]
Process SpawnProcess-2 fails with an identical traceback.

How can I correctly combine data parallelism with model parallelism? Thanks in advance!

Unfortunately, there are assumptions in torch.nn.parallel.DistributedDataParallel today that prevent you from doing this: when it is given multiple device_ids it replicates the module across those devices, and that replication expects every parameter to live on device_ids[0], which is exactly the check that raises your error. We’re working on some changes to DDP to make this possible. Stay tuned.
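For reference, the pattern that DDP does support today is plain data parallelism, with the whole model on one GPU per process. Here is a minimal sketch of that supported setup, reusing MyModel from the post above (the run_single_device name and the address/port are only illustrative):

import torch
import torch.distributed
import torch.nn as nn


def run_single_device(rank, world_size):
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='tcp://localhost:10000',
        world_size=world_size,
        rank=rank,
    )
    device = torch.device('cuda', rank)
    model = MyModel().to(device)  # whole model on a single GPU, no model_parallel() split
    model = nn.parallel.DistributedDataParallel(
        module=model,
        device_ids=[rank],  # exactly one device per process
        output_device=rank,
    )
    model(torch.randn(1, 784, device=device))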

cc @mrshenli

Added a link to this post in this GitHub issue so we don’t forget about it.

Thanks, pietern! Hope this upgrade will be available soon!

This should be possible after #19271. Here is a tutorial. @lyy1994, could you please help verify whether it works for you?
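For anyone reading this later, here is a minimal sketch of what the combination could look like once DDP accepts multi-device modules (this assumes a release that includes the change referenced above; run_combined, the address/port, and the two-GPUs-per-replica layout are illustrative, and MyModel/model_parallel come from the original post). The key difference is that no device_ids or output_device are passed, so DDP does not try to replicate the module onto a single device:

import torch
import torch.distributed
import torch.nn as nn


def run_combined(rank, world_size, model_size=2):
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='tcp://localhost:10000',
        world_size=world_size,
        rank=rank,
    )
    first_gpu = rank * model_size
    model = MyModel()
    model.model_parallel(first_gpu)  # fc1 on cuda:first_gpu, fc2 on cuda:first_gpu + 1
    # With a multi-device module, let the module manage placement itself and
    # do not pass device_ids / output_device to DDP.
    model = nn.parallel.DistributedDataParallel(module=model)
    model(torch.randn(1, 784).cuda(first_gpu))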