How to combine model parallel with data parallel?

I have designed a big model whose architecture looks like this:

class BigModel(nn.Module):
    def __init__(self, encoder: nn.Module, component1: nn.Module, component2: nn.Module, component3: nn.Module):
        super(BigModel, self).__init__()
        # Data parallel: replicate the encoder across four GPUs.
        self.encoder = nn.DataParallel(encoder, device_ids=["cuda:0", "cuda:1", "cuda:2", "cuda:3"])
        self.component1 = component1
        self.component2 = component2
        self.component3 = component3

    def deploy(self):
        # Model parallel: place the remaining components on the other two GPUs.
        self.component1.to("cuda:4")
        self.component2.to("cuda:4")
        self.component3.to("cuda:5")

I have a single machine with six Tesla V100 GPUs (32 GB).
The encoder is very big (a BERT-like model), so I want to use four GPUs for encoding and the remaining GPUs for the other parts of the model.

It works, but sometimes some of the encoder outputs are lost.

For example, I feed a batch of sequences of size (16, 256) to the encoder. DataParallel should split it into 4 tensors of size (4, 256), encode them in parallel, then gather the outputs and merge them into a tensor of size (16, 256, 1024).

But sometimes I only get an output of size (12, 256, 1024), which means one split of the data is lost. I cannot figure out the reason for this problem…

Can anyone explain this problem or suggest a way to combine model parallelism and data parallelism?

This is very weird. Are you using DataParallel instead of DistributedDataParallel? In your encoder's forward function, could you please add some prints to check whether parallel_apply indeed spawned one thread per device, and whether each thread is getting a (4, 256) input and emitting a (4, 256, 1024) output?

I do use nn.DataParallel instead of nn.DistributedDataParallel.

I have checked the forward function of the encoder and found that when the program goes wrong, there are only 3 encode calls instead of 4. I don't know what is going on…

You mean there is no error message at all? Could you please share a minimal repro?

I think I have found the cause of the problem. Here is my code; it is a runnable demo, and you need to install the packages it imports:


import transformers
import torch
from transformers import ElectraModel
from transformers import AdamW
from tqdm import tqdm

class Encoder(torch.nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.encoder = ElectraModel.from_pretrained("google/electra-large-discriminator")
    def forward(self, x1, x2):
        return [self.encoder(x1), self.encoder(x2)]

class BigModel(torch.nn.Module):
    def __init__(self):
        super(BigModel, self).__init__()
        self.encoder = torch.nn.DataParallel(Encoder(), device_ids=["cuda:0", "cuda:1", "cuda:2", "cuda:3"])
        self.l1 = torch.nn.Linear(1024, 1024)
        self.l2 = torch.nn.Linear(1024, 1024)
        self.l3 = torch.nn.Linear(1024, 2)
        self.loss = torch.nn.CrossEntropyLoss()

    def deploy(self):
        # Place the non-encoder layers on the two GPUs not used by DataParallel.
        self.l1.to("cuda:4")
        self.l2.to("cuda:4")
        self.l3.to("cuda:5")
        self.loss.to("cuda:5")

    def forward(self, x: torch.LongTensor, x1: torch.LongTensor, y: torch.LongTensor):
        res = self.encoder(x, x1)
        # Shapes of the gathered encoder outputs; their first dimension should match the input batch sizes.
        print(res[0][0].shape)
        print(res[1][0].shape)
        x2 = self.l1(res[0][0].to("cuda:4")).to("cuda:5")
        x3 = self.l2(res[1][0].to("cuda:4")).to("cuda:5")
        x4 = self.l3((x2.unsqueeze(1) + x3.unsqueeze(0)).mean(1))
        return self.loss(x4.reshape((-1, 2)), y.reshape((-1,))).mean()

def main():
    model = BigModel().cuda()
    model.deploy()  # move l1/l2/l3/loss to cuda:4 and cuda:5, as forward() expects
    optimizer = AdamW(model.parameters(), lr=3e-5)
    # assume we train a token-level binary classification task for 10 epochs, and the data set has 100 batches
    for epoch in range(10):
        for d in tqdm(range(100)):
            ids = torch.randint(0, 30000, (16, 256))
            ids1 = torch.randint(0, 30000, (6, 256))  # strange batch size that leads to the problem
            y = torch.randint(0, 2, (16, 256))
            ids = ids.cuda()
            ids1 = ids1.cuda()
            y = y.to("cuda:5")
            loss = model(ids, ids1, y)

if __name__ == "__main__":
    main()


After I start this Python program, I get the following error output:

  0%|                                                                                                                           | 0/100 [00:00<?, ?it/s]
torch.Size([12, 256, 1024])
torch.Size([6, 256, 1024])
  0%|                                                                                                                           | 0/100 [00:09<?, ?it/s]
Traceback (most recent call last):
  File "", line 58, in <module>
  File "", line 53, in main
    loss = model(ids, ids1, y)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "", line 36, in forward
    return self.loss(x4.reshape((-1, 2)), y.reshape((-1,))).mean()
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/", line 962, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/", line 2468, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/", line 2262, in nll_loss
    .format(input.size(0), target.size(0)))
ValueError: Expected input batch_size (3072) to match target batch_size (4096).

This error is raised in the loss function; torch does not check that the batch sizes of all input tensors agree before that point. At the same time, one split of one encoder result is lost.
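One way to surface the mismatch earlier is to compare batch sizes explicitly before computing the loss. The following is only a minimal sketch: `check_batch_sizes` is a hypothetical helper (not part of torch), and it assumes shapes are passed as plain tuples, for example `tuple(t.shape)`.

```python
def check_batch_sizes(**shapes):
    """Raise ValueError if the named shapes disagree on their first dimension.

    Hypothetical helper for illustration; pass plain tuples such as tuple(t.shape).
    """
    batch_sizes = {name: shape[0] for name, shape in shapes.items()}
    if len(set(batch_sizes.values())) > 1:
        raise ValueError(f"batch size mismatch: {batch_sizes}")
    return batch_sizes


# Shapes taken from the run above: 12 encoded rows against a 16-row target.
try:
    check_batch_sizes(encoder_out=(12, 256, 1024), target=(16, 256))
except ValueError as err:
    message = str(err)
    print(message)
```

Called right after the encoder, a check like this would fail at the point where the rows go missing rather than later inside the loss.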

@mrshenli Here is my demo program. Is there a potential bug?

I can reproduce this locally, and here is what happened:

DP will try to scatter the inputs to the given devices.

In this case, the given devices are [0, 1, 2, 3], and the inputs are 16X* and 6X* tensors. The scatter function traverses the list and tries to scatter every tensor in it.

In your code, this is the same as calling x.chunk(4, 0) and x1.chunk(4, 0). x.chunk(4, 0) returns 4 tensors, because 16 is divisible by 4. But x1.chunk(4, 0) returns only 3 tensors: when the size is not divisible by the number of chunks, chunk uses a chunk size of ceil(6 / 4) = 2, and 6 rows fill exactly three chunks of that size, so nothing is left for a fourth split.

See: torch.chunk — PyTorch 1.8.0 documentation
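The arithmetic can be checked without any GPUs. The sketch below is a pure-Python stand-in, not torch itself: the helper name `chunk_sizes` is made up for illustration, and it assumes the rule from the torch.chunk documentation, namely equal-sized pieces of ceil(n / chunks) rows, with a smaller last piece or fewer pieces when that does not divide evenly.

```python
import math


def chunk_sizes(n, chunks):
    """Piece sizes for a chunk-style split of n rows into at most `chunks` pieces."""
    size = math.ceil(n / chunks)  # every piece except possibly the last has this size
    sizes = []
    remaining = n
    while remaining > 0:
        take = min(size, remaining)
        sizes.append(take)
        remaining -= take
    return sizes


print(chunk_sizes(16, 4))  # [4, 4, 4, 4]
print(chunk_sizes(6, 4))   # [2, 2, 2]
```

The second call shows the case from this thread: three pieces come back even though four were requested.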

Then scatter zips together the splits from x and x1. Because zip stops at the shorter list, x's last split is dropped.
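The dropped split can be reproduced with plain lists. This is a simplified sketch of the pairing step only, assuming a made-up `chunk` helper that splits a list into pieces of ceil(len / chunks) rows, like torch.chunk does for tensors; it is not the actual DataParallel code.

```python
def chunk(rows, chunks):
    """Split a list into pieces of ceil(len(rows) / chunks) rows each."""
    size = max(1, -(-len(rows) // chunks))  # ceiling division, at least 1
    return [rows[i:i + size] for i in range(0, len(rows), size)]


x = list(range(16))   # stands in for the (16, 256) batch
x1 = list(range(6))   # stands in for the (6, 256) batch

# One (x_piece, x1_piece) pair per replica; zip stops at the shorter list.
per_replica = list(zip(chunk(x, 4), chunk(x1, 4)))

print(len(per_replica))                      # 3 replicas get work, not 4
print(sum(len(a) for a, _ in per_replica))   # 12 rows of x survive
print(sum(len(b) for _, b in per_replica))   # 6 rows of x1 survive
```

The 12 surviving rows of x match the torch.Size([12, 256, 1024]) printed in the failing run.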

Curious, what is the expected behavior here? Is it [4, 2], [4, 2], [4, 2], [4, 0], or [4, 2], [4, 2], [4, 1], [4, 1]?
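To make the two candidates concrete, here is a small pure-Python sketch of both splitting policies. The function names `split_with_empty` and `split_balanced` are hypothetical and only compute per-device batch sizes; neither is an existing DataParallel option.

```python
def split_with_empty(n, parts):
    """Chunk-style sizes padded with zeros so that every part gets an entry."""
    size = -(-n // parts)  # ceiling division
    sizes = []
    remaining = n
    for _ in range(parts):
        take = min(size, remaining)
        sizes.append(take)
        remaining -= take
    return sizes


def split_balanced(n, parts):
    """Sizes that differ by at most one, with larger pieces first."""
    base, extra = divmod(n, parts)
    return [base + 1 if i < extra else base for i in range(parts)]


# Per-device (x rows, x1 rows) pairs for the 16-row and 6-row inputs.
print(list(zip(split_with_empty(16, 4), split_with_empty(6, 4))))
# [(4, 2), (4, 2), (4, 2), (4, 0)]
print(list(zip(split_balanced(16, 4), split_balanced(6, 4))))
# [(4, 2), (4, 2), (4, 1), (4, 1)]
```

Either policy keeps all four pieces of x paired with something. Under the first, one replica would receive an empty x1 piece, which the wrapped module would have to tolerate.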