Problem of using multiprocessing along with DistributedDataParallel

Here’s the situation. Assume that I have 4 GPUs: ‘gpu0’, ‘gpu1’, ‘gpu2’ and ‘gpu3’. I need to train one network (net1), but the input to net1 is produced by another network (net2). I want to use all of my GPUs for the training.

So first, I use two processes in my program. Process1 performs the forward step of net2 to produce the inputs of net1, then uses torch.distributed.broadcast to send those inputs to process2. Process2 receives the inputs sent by process1, then performs the forward and backward steps to train net1.

In process1, I use torch.nn.parallel.DistributedDataParallel to put net2 (the generator) on ‘gpu0’ and ‘gpu1’.
In process2, I use torch.nn.parallel.DistributedDataParallel to put net1 (the RNN being trained) on ‘gpu2’ and ‘gpu3’.

Here is how I implement a toy model:

import os

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.multiprocessing import Process
from torch.utils.data import Dataset, DataLoader

class Random_Dataset(Dataset):
    """Map-style dataset of ``num`` random vectors, each of length ``dim``.

    The samples are drawn once with ``torch.rand`` when the dataset is
    constructed, so indexing the same position twice returns the same
    values.

    Args:
        num: Number of samples; also the value returned by ``len()``.
        dim: Length of each sample vector.
    """

    def __init__(self, num=1200, dim=300):
        self.num = num
        self.dim = dim
        # Values are uniform on [0, 1) as produced by torch.rand.
        self.data = torch.rand(num, dim)

    def __len__(self):
        """Return the number of samples."""
        return self.num

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.data[idx]

class Model(nn.Module):
    """Batch-first multi-layer LSTM applied to padded, variable-length input.

    Args:
        input_size: Feature size of each time step.
        hidden_size: Hidden state size of the LSTM (and of each output step).
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super(Model, self).__init__()
        self.rnn = nn.LSTM(
            input_size=input_size, hidden_size=hidden_size,
            num_layers=num_layers, batch_first=True,
        )

    def forward(self, x, lengths):
        """Run the LSTM over ``x`` and return the re-padded outputs.

        Args:
            x: Padded batch of shape (batch, time, input_size).
            lengths: Valid length of every sequence in ``x``, sorted in
                decreasing order (``pack_padded_sequence`` is called with
                its default ``enforce_sorted=True``). May be a list or a
                1-D tensor.

        Returns:
            Tensor of shape (batch, time, hidden_size); positions past each
            sequence's length are zero-filled by ``pad_packed_sequence``.
        """
        # Remember the padded length so the output is padded back to the
        # same number of time steps as the input, not just to max(lengths).
        total_length = x.size(1)
        # pack_padded_sequence requires a tensor `lengths` to live on the
        # CPU, while callers may hand in a CUDA tensor.
        if torch.is_tensor(lengths):
            lengths = lengths.cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths, batch_first=True
        )
        outputs, _ = self.rnn(packed)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True, total_length=total_length
        )
        return outputs

class Generator(nn.Module):
    """Single 1-D convolution followed by a sigmoid.

    Args:
        in_channels: Input channels of the convolution. ``forward`` inserts
            exactly one channel axis, so this must be 1 for 2-D input.
        out_channels: Number of convolution filters.
        kernel_size: Width of the convolution kernel. With an odd width the
            padding below keeps the output length equal to the input length.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super(Generator, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size, padding=padding
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, length) to (batch, out_channels, length').

        Returns:
            Sigmoid-activated convolution output, with values in [0, 1].
        """
        # Conv1d expects (batch, channels, length); add the channel axis.
        x = torch.unsqueeze(x, 1)
        x = self.conv(x)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(x)

def init_process(rank, world_size, backend, func, params):
    """Join the default process group, then run ``func`` in this process.

    Args:
        rank: Rank of the calling process within the default group.
        world_size: Total number of processes joining the group.
        backend: Name of the torch.distributed backend to initialise.
        func: Worker function; called as ``func(*params, group)``.
        params: Positional arguments for ``func`` (the group is appended).
    """
    # Rendezvous endpoint for the env:// init method. An empty address
    # cannot be used to reach the master, so the loopback address is set.
    # NOTE(review): assumes all ranks run on one machine -- confirm.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    # new_group is a collective call: every rank has to execute it, which is
    # why it lives here rather than inside one of the worker functions.
    group = dist.new_group([0, 1])
    func(*params, group)

def generate_data(dataloader, num_epochs, rank, device_ids, network_params,
                  use_cuda, group):
    """Run the generator forward pass and broadcast each output batch.

    Args:
        dataloader: Iterable yielding input batches for the generator.
        num_epochs: Number of passes over ``dataloader``.
        rank: Rank of this process; used as the broadcast source.
        device_ids: GPU ids handed to DistributedDataParallel.
        network_params: Positional arguments for ``Generator``.
        use_cuda: Whether the model and batches are moved to the GPU.
        group: Process group over which the outputs are broadcast.
    """
    generator = Generator(*network_params)
    if use_cuda:
        generator = generator.cuda()
    print("Network initialized!")
    # DistributedDataParallel requires device_ids=None for a CPU module.
    generator = nn.parallel.DistributedDataParallel(
        generator, device_ids=device_ids if use_cuda else None
    )
    for epoch in range(num_epochs):
        for i_batch, batch in enumerate(dataloader):
            # Keep the batch on the same device as the model.
            if use_cuda:
                batch = batch.cuda()
            rnn_input = generator(batch)
            # This rank is the source; the receiver must call broadcast
            # with the same src and a tensor of matching shape.
            dist.broadcast(rnn_input, rank, group)
            print("epoch:{}, batch_num:{}, broadcast finished!".format(
                epoch, i_batch
            ))

def run_rnn(num_batchs, num_epochs, src_rank, device_ids, network_params,
            input_size, use_cuda, group):
    """Receive broadcast inputs and train the RNN on them.

    Args:
        num_batchs: Number of batches to receive per epoch; must match the
            number of broadcasts the sender issues per epoch.
        num_epochs: Number of epochs.
        src_rank: Rank of the process that broadcasts the inputs.
        device_ids: GPU ids handed to DistributedDataParallel.
        network_params: Positional arguments for ``Model``.
        input_size: Shape of every received input tensor.
        use_cuda: Whether the model and tensors are moved to the GPU.
        group: Process group over which the inputs are broadcast.
    """
    rnn = Model(*network_params)
    if use_cuda:
        rnn = rnn.cuda()
    print("Network initialized!")
    # DistributedDataParallel requires device_ids=None for a CPU module.
    rnn = nn.parallel.DistributedDataParallel(
        rnn, device_ids=device_ids if use_cuda else None
    )
    optimizer = optim.Adam(rnn.parameters(), lr=1e-4)
    for epoch in range(num_epochs):
        for i_batch in range(num_batchs):
            # Receive buffer; its contents are overwritten by broadcast.
            rnn_input = torch.empty(input_size)
            if use_cuda:
                rnn_input = rnn_input.cuda()
            dist.broadcast(rnn_input, src_rank, group)
            print("epoch:{}, batch_num:{}, receive finished!".format(
                epoch, i_batch
            ))
            batch_size = rnn_input.size(0)
            # Random sequence lengths, sorted in decreasing order as the
            # packing step inside the model expects.
            lengths = np.random.randint(low=3, high=30, size=(batch_size))
            lengths = -np.sort(-lengths)
            lengths = torch.from_numpy(lengths).long()
            if use_cuda:
                lengths = lengths.cuda()
            out = rnn(rnn_input, lengths)
            # Toy objective: the sum of all outputs.
            out = torch.sum(out)
            # NOTE(review): the optimisation step was missing from the
            # listing; reconstructed from the unused optimizer -- confirm.
            optimizer.zero_grad()
            out.backward()
            optimizer.step()
    rnn = rnn.cpu()
    # NOTE(review): the original checkpoint path was lost (only '' was
    # left); 'rnn.pth' is a placeholder -- confirm the intended location.
    torch.save(rnn.state_dict(), 'rnn.pth')

def main(use_cuda=False):
    """Spawn the generator process and the RNN-training process, then wait.

    Rank 0 runs ``generate_data`` with device ids [0, 1]; rank 1 runs
    ``run_rnn`` with device ids [2, 3]. Both are started through
    ``init_process``.

    Args:
        use_cuda: Forwarded to both workers.
    """
    world_size = 2
    dataset = Random_Dataset()
    dataloader = DataLoader(
        dataset, batch_size=12, shuffle=True, num_workers=1
    )
    num_epochs = 2
    # 1200 samples / batch size 12 = 100 broadcasts per epoch.
    num_batchs = 100
    generator_device_ids = [0, 1]
    rnn_device_ids = [2, 3]
    generator_params = (1, 50, 5)
    rnn_params = (300, 300, 3)

    # NOTE(review): the two ranks request different backends ('nccl' and
    # 'gloo') for the same process group. Kept exactly as posted because
    # this listing reproduces a reported hang -- confirm before reuse.
    p1 = Process(
        target=init_process,
        args=(0, world_size, 'nccl', generate_data,
              (dataloader, num_epochs, 0, generator_device_ids,
               generator_params, use_cuda)),
    )

    p2 = Process(
        target=init_process,
        args=(1, world_size, 'gloo', run_rnn,
              (num_batchs, num_epochs, 0,
               rnn_device_ids, rnn_params, torch.Size((12, 50, 300)),
               use_cuda)),
    )

    processes = [p1, p2]
    # Start both before joining either: each rank blocks in the rendezvous
    # until the other one has joined.
    for p in processes:
        p.start()
    for p in processes:
        p.join()

if __name__ == '__main__':
    # NOTE(review): the original call was lost from the listing;
    # use_cuda=True is assumed because the script targets four GPUs.
    main(use_cuda=True)

The code hangs when it calls “torch.nn.parallel.DistributedDataParallel”.
Clearly the two networks in the code above are small enough to fit on a single GPU, but if I want the code to run the way I described at the start, how should I modify it?

First of all: if your GPUs are all on one machine, there is no need for DistributedDataParallel; just use plain DataParallel instead.

Second: CUDA calls are asynchronous. If you split the networks this way, the one being trained would (probably) be slower, and therefore the GPUs holding the already-trained net would sit idle. To make optimal use of your GPUs, I would recommend wrapping both models in a single wrapper module and distributing this wrapper with DataParallel across all 4 GPUs with an increased batch size. In this case a copy of each net is created on every GPU, the batch is split across your GPUs, and the gradients are accumulated.

What have you tried so far? I have the same problem as you.
Could you please share any suggestions?