Problem combining model parallelism and DDP across multiple nodes

Hi all,

I have both a large model (it cannot fit in one GPU's memory) and a large dataset (I need more nodes to accelerate training), so I am trying to combine model parallelism with DDP following this tutorial. I tried the following approach, but it is not working:

  • The computing platform I am on has 25 nodes, each with 4 GPUs (16 GB of memory each).
  • On each node, I spin up 1 process that sees all 4 GPUs, cut my model into 4 parts, and send each part to a different GPU on that node for model parallelism. I tested this code on one node and it works without any problems.
  • I then try to wrap this model, which spans the 4 GPUs of one node, in DDP and replicate it across nodes. The idea is that the 4 GPUs of one node act as one “bigger” GPU, and DDP replicates the model laid out across those 4 GPUs onto the other nodes for data parallelism (a minimal sketch of this layout follows the list).
  • The problem is that my code hangs at the line model = DDP(model), with no error or warning message. It never proceeds or crashes until my time allocation on the cluster runs out.
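
To make the layout concrete, here is a minimal toy sketch of what I mean (illustrative module and sizes only, not my actual code; the launcher is assumed to set the usual RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT environment variables):

import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


class FourGpuModel(nn.Module):
    """Toy stand-in for my model: one part per local GPU."""

    def __init__(self):
        super().__init__()
        self.parts = nn.ModuleList(
            [nn.Linear(1024, 1024).to(f"cuda:{i}") for i in range(4)]
        )

    def forward(self, x):
        for i, part in enumerate(self.parts):
            x = part(x.to(f"cuda:{i}"))  # move activations to the next part's GPU
        return x


def main():
    # one process per node; the backend choice here is just for illustration
    dist.init_process_group(backend="nccl")
    model = FourGpuModel()
    # device_ids must stay None for a multi-device module
    model = DDP(model)  # my real code hangs at this line
    out = model(torch.randn(8, 1024, device="cuda:0"))
    dist.destroy_process_group()


if __name__ == "__main__":
    main()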

My questions are:

  1. Does this approach look correct to you? If yes, can you help me address the hang? (I can attach my code if needed.)
  2. If you think the above approach cannot work, what would you suggest instead? (I saw some posts suggesting to use RPC for cross-node communication when model parts are spread across nodes, but my model parts all sit on one node, so I think the above method should still work.)

Thanks!


Any comments or suggestions are welcome. Great community, please help!

I’m hitting a similar problem!

I haven’t solved mine yet but here’s one potential lead:

  • Looking into the PiPPy codebase, I noticed that they use rpc_async to call their init_data_parallel method, which makes the DDP(model) call, in order to avoid a deadlock (rough sketch of that pattern at the end of this post).

I’m gonna try the above fix out inspired by PiPPy. Will let you know if it works!
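
For reference, here's a rough sketch of what I think that pattern looks like. This is not PiPPy's actual code; the worker names ("worker{rank}") and the module-level _state dict are placeholders of my own, and it assumes init_process_group and rpc.init_rpc have already run on every rank:

from torch.distributed import rpc
from torch.nn.parallel import DistributedDataParallel as DDP

# Filled in on every rank during startup: the local model replica and a
# process group containing only the "driver" ranks. new_group is a collective
# over all ranks, so it has to be created up front, not inside the RPC callback.
_state = {"model": None, "dp_group": None}


def init_data_parallel():
    # Runs in an RPC server thread on each driver rank: enters the DDP
    # constructor, which is a collective over the driver-only process group.
    _state["model"] = DDP(_state["model"], process_group=_state["dp_group"])


def make_data_parallel(driver_ranks):
    # Called from one rank's main thread: ask every driver rank (possibly
    # including ourselves) to enter the DDP constructor concurrently, then wait
    # on the futures, so no rank blocks in the collective before the others start.
    futs = [rpc.rpc_async(f"worker{r}", init_data_parallel) for r in driver_ranks]
    for fut in futs:
        fut.wait()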

Here’s an example showing the model parallelism + DDP setup that worked for me. TL;DR: I initialize RPC, create a process group containing the “driver” ranks (one per model replica) that run the forward and backward pass, and then set up my model with DDP(Pipe(model), process_group=process_group).

Notes:

  1. The example code below won’t run out of the box because it’s missing some of the source (e.g., my Config class), but hopefully it gives you an idea of how to get it set up!
  2. My example is written assuming 2-stage model parallelism on a 2-GPU machine. To adapt it for an N-GPU machine, you’d need to change the logic for setting input_device and output_device and change the condition for running the forward pass to config.rank % num_pipeline_stages == 0 (see the small sketch after the full script).

import os
import math
import argparse
from datetime import timedelta
from collections import defaultdict

import torch
import torch.nn as nn

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import rpc, init_process_group, new_group, Backend, destroy_process_group, barrier
from torch.distributed.pipeline.sync import Pipe

from model import Decoder
from utils import Config, Precision

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, help="Path to training config file in yaml format")


class PipelineParallel(nn.Module):

    def __init__(self, pipeline, loss_fn, vocab_size, ignore_index=-1):
        super().__init__()
        self.pipeline = pipeline
        self.loss_fn = loss_fn
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index

    def __call__(self, x, y=None):
        _, t = x.shape

        loss = None
        logits = self.pipeline(x).local_value()

        if y is not None:
            loss = self.loss_fn(
                input=logits.view(-1, self.vocab_size),
                target=y.view(-1),
                ignore_index=self.ignore_index,
            )
        return logits, loss

    def make_data_parallel(self, process_group):
        self.pipeline = DDP(self.pipeline, process_group=process_group)

    @classmethod
    def from_sequential_model(
        cls,
        model,
        num_pipeline_stages,
        num_micro_batches,
        activation_ckpt_mode,
        deferred_batch_norm,
    ):
        """
        Approach copied from https://medium.com/pytorch/pytorch-data-parallel-best-practices-on-google-cloud-6c8da2be180d
        """
        local_rank = int(os.environ["LOCAL_RANK"])
        first_stage_rank = (local_rank // num_pipeline_stages) * num_pipeline_stages

        stage_idx = 0
        num_stage_params = 0
        stage_to_layers = defaultdict(list)
        sequential = model.as_sequential()
        num_model_params = sum([p.numel() for p in sequential.parameters()])
        max_params_per_stage = math.ceil(num_model_params / num_pipeline_stages)

        # partition the model into stages
        for layer in sequential:
            num_layer_params = sum([p.numel() for p in layer.parameters()])
            is_stage_full = num_stage_params + num_layer_params > max_params_per_stage
            is_last_stage = stage_idx == num_pipeline_stages - 1
            if is_stage_full and not is_last_stage:
                stage_idx += 1
                num_stage_params = num_layer_params
            else:
                num_stage_params += num_layer_params
            stage_to_layers[stage_idx].append(layer)

        # put each stage onto a GPU
        for i, stage in stage_to_layers.items():
            device = f"cuda:{first_stage_rank+i}"
            for layer in stage:
                layer.to(device=device)

        assert len(stage_to_layers) == num_pipeline_stages
        pipeline = Pipe(
            module=nn.Sequential(*[nn.Sequential(*stage_to_layers[j]) for j in range(num_pipeline_stages)]),
            chunks=num_micro_batches,
            checkpoint=activation_ckpt_mode,
            deferred_batch_norm=deferred_batch_norm,
        )
        return cls(pipeline=pipeline, loss_fn=model.loss_fn, vocab_size=model.vocab_size)


def main(config):
    config.rank = int(os.environ['RANK'])
    config.local_rank = int(os.environ['LOCAL_RANK'])
    config.world_size = int(os.environ['WORLD_SIZE'])
    config.master_addr = os.environ['MASTER_ADDR']
    config.master_port = os.environ['MASTER_PORT']

    init_process_group(
        init_method='tcp://' + str(config.master_addr) + ':' + str(config.master_port),
        backend=Backend.GLOO, rank=config.rank, world_size=config.world_size
    )
    rpc.init_rpc(
        "worker:" + str(config.rank),
        rank=config.rank,
        world_size=config.world_size,
    )
    driver_ranks = [i for i in range(config.world_size) if i % config.model.num_pipeline_stages == 0]
    process_group = new_group(ranks=driver_ranks, backend=Backend.NCCL, timeout=timedelta(days=365))

    if config.rank % 2 == 0:
        batch_size = 1
        max_len = config.model.max_len
        rand_ints = torch.randint(0, 100, (batch_size, max_len + 1))
        x, y = rand_ints[:, :-1], rand_ints[:, 1:]
        amp_context = Precision.get_amp_context(config.model.precision, config.compute.device)

        input_device = "cuda:0"
        output_device = "cuda:1"
        model = Decoder(config.model)
        model = PipelineParallel.from_sequential_model(model, 2, 2, "never", False)
        model.make_data_parallel(process_group=process_group)
        with amp_context:
            logits, loss = model(x.to(input_device), y.to(output_device))

    barrier(process_group)
    rpc.shutdown(graceful=True)
    destroy_process_group()


if __name__ == "__main__":
    args = parser.parse_args()
    config = Config.from_yaml(args.config)
    main(config)
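
Regarding note 2 above, here’s roughly how the device selection and driver check could be generalized for num_pipeline_stages GPUs per replica (a small untested helper of my own, mirroring the variable names in the script):

def stage_devices(rank, local_rank, num_pipeline_stages):
    """Pick the driver flag and the input/output devices for a pipeline
    replica whose stages occupy num_pipeline_stages consecutive local GPUs."""
    is_driver = rank % num_pipeline_stages == 0
    first_gpu = (local_rank // num_pipeline_stages) * num_pipeline_stages
    input_device = f"cuda:{first_gpu}"                             # first stage
    output_device = f"cuda:{first_gpu + num_pipeline_stages - 1}"  # last stage
    return is_driver, input_device, output_device

In main above, that would replace the hard-coded if config.rank % 2 == 0, input_device = "cuda:0", and output_device = "cuda:1" lines.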

Thank you so much for your help, @jasonkrone. I just saw your message; let me test it and get back to you.