How to implement distributed model parallel

I’d like to implement distributed model parallelism at the module level, e.g. for nn.Linear. The following snippet splits the module and distributes the pieces across multiple GPUs, but the computation does not actually run in parallel.

import math

import torch
import torch.nn as nn


class DistLinear(nn.Module):
    """Splits an nn.Linear along its output dimension across multiple GPUs."""

    def __init__(self, input_size, output_size, bias=True, io_gpu=0, num_gpus=2):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.bias = bias
        self.io_gpu = io_gpu      # device that holds the input and the gathered output
        self.num_gpus = num_gpus

        self.setup_modules()

    def setup_modules(self):
        # Split the output dimension into num_gpus buckets; the last bucket
        # gets whatever remains so the bucket sizes sum to output_size.
        self.bucket_size = math.ceil(self.output_size / self.num_gpus)
        last_bucket_size = self.output_size - self.bucket_size * (self.num_gpus - 1)
        self.modular = nn.ModuleList([self.get_module(self.bucket_size)
                                      for _ in range(self.num_gpus - 1)])
        self.modular.append(self.get_module(last_bucket_size))
        # Place each shard on its own GPU.
        for i in range(self.num_gpus):
            self.modular[i].cuda(i)

    def get_module(self, output_size):
        return nn.Linear(self.input_size, output_size, bias=self.bias)

    def forward(self, x):  # x: batch x hidden
        xs = []
        for i in range(self.num_gpus):
            _x = x.cuda(i)                # copy the input to GPU i
            _x = self.modular[i](_x)      # run this shard's Linear
            _x = _x.cuda(self.io_gpu)     # move the partial output back to the I/O GPU
            xs.append(_x)
        return torch.cat(xs, dim=-1)      # concatenate the shards along the output dim
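
For context, a minimal usage sketch (the sizes here are arbitrary examples, and at least two CUDA devices are assumed):

# Minimal usage sketch; sizes are arbitrary and at least 2 GPUs are assumed.
layer = DistLinear(input_size=512, output_size=1024, io_gpu=0, num_gpus=2)
x = torch.randn(32, 512).cuda(0)   # batch of 32 on the I/O GPU
y = layer(x)                       # each shard runs one after another in the loop
print(y.shape)                     # torch.Size([32, 1024]), gathered on GPU 0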

Are you looking for something like DataParallel?
https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html

No, I want to do model parallelism. In the snippet above I used a for loop, so the shards run one after another rather than in parallel.
I noticed that nn.parallel.parallel_apply(modulelist, inputs) looks suitable for replacing the loop; is that correct?
When I actually compare the speed of the for loop and parallel_apply, parallel_apply seems a little faster.
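
For reference, here is a rough sketch of how the loop could be replaced with parallel_apply (same DistLinear class as above; parallel_apply runs each module/input pair in its own thread, so kernels on different GPUs can overlap):

def forward(self, x):  # x: batch x hidden
    # Copy the input to every shard's device first.
    inputs = [x.cuda(i) for i in range(self.num_gpus)]
    # parallel_apply calls self.modular[i](inputs[i]) in one thread per shard,
    # so the GPUs can work concurrently instead of one after another.
    outputs = nn.parallel.parallel_apply(self.modular, inputs)
    # Move the partial outputs back to the I/O GPU and concatenate them.
    outputs = [o.cuda(self.io_gpu) for o in outputs]
    return torch.cat(outputs, dim=-1)

The input copies and the final concatenation still go through the I/O GPU, so the overlap is mainly in the per-shard matrix multiplies.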


Hi @Ryobot, distributed model parallelism across different machines requires inter-node communication, and right now we don’t have good support for that the way the DistributedDataParallel wrapper supports data parallelism.

For single-node, multi-GPU model parallelism (analogous to the DataParallel wrapper), yes, parallel_apply should work in your case.
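
For reference, DataParallel itself is composed from these primitives: it scatters the input, replicates the module, runs parallel_apply, and gathers the outputs. A simplified sketch, not the actual implementation:

# Simplified sketch of how DataParallel composes these primitives;
# the real code lives in torch/nn/parallel/data_parallel.py.
def data_parallel_forward(module, x, device_ids, output_device):
    inputs = nn.parallel.scatter(x, device_ids)            # split the batch across GPUs
    replicas = nn.parallel.replicate(module, device_ids)   # copy the module to each GPU
    outputs = nn.parallel.parallel_apply(replicas, inputs) # run the replicas concurrently
    return nn.parallel.gather(outputs, output_device)      # collect the outputs on one GPU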


Thanks @ailzhang, my problem seems to be solved.