Problem combining model parallelism and DDP on multiple nodes

Here’s an example showing the model parallelism + DDP setup that worked for me. TLDR: I initialize RPC, create a process group containing the “driver” ranks (one per model replica) that run the forward and backward pass, and then set up my model with DDP(Pipe(model), process_group=process_group).
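In condensed form, the wiring looks roughly like this (a rough sketch with made-up names like setup_pipe_ddp and staged_model, just to show the order of the calls; the full example is further down):

import torch.distributed as dist
from torch.distributed import rpc
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.pipeline.sync import Pipe


def setup_pipe_ddp(rank, world_size, num_pipeline_stages, staged_model, num_micro_batches):
    # staged_model: an nn.Sequential of pipeline stages, each already placed on its GPU
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)

    # one "driver" rank per model replica runs the forward/backward pass
    driver_ranks = [r for r in range(world_size) if r % num_pipeline_stages == 0]
    pg = dist.new_group(ranks=driver_ranks, backend="nccl")

    model = None
    if rank in driver_ranks:
        pipe = Pipe(staged_model, chunks=num_micro_batches)  # Pipe requires RPC to be initialized
        model = DDP(pipe, process_group=pg)
    return model, pg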

Notes:

  1. The example code below won’t run out of the box because it’s missing some of the source (e.g., my Config class), but hopefully it gives you an idea of how to get it set up!
  2. My example is written assuming 2-stage model parallelism on a 2-GPU machine. To adapt it to an N-GPU machine, you’d need to change the logic for setting input_device and output_device and change the condition for running the forward pass to config.rank % num_pipeline_stages == 0 (see the sketch right after these notes).
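For reference, that N-GPU generalization would look roughly like this inside main(), replacing the hard-coded rank check and devices (untested sketch, reusing the first_stage_rank convention from from_sequential_model):

num_pipeline_stages = config.model.num_pipeline_stages
if config.rank % num_pipeline_stages == 0:  # driver ranks only
    first_stage_rank = (config.local_rank // num_pipeline_stages) * num_pipeline_stages
    input_device = f"cuda:{first_stage_rank}"                             # first pipeline stage
    output_device = f"cuda:{first_stage_rank + num_pipeline_stages - 1}"  # last pipeline stage

And here's the full example: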
import os
import math
import argparse
from collections import defaultdict
from datetime import timedelta

import torch
import torch.nn as nn

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.pipeline.sync import Pipe
from torch.distributed import rpc, init_process_group, new_group, Backend, destroy_process_group, barrier

from model import Decoder
from utils import Config, Precision

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, help="Path to training config file in yaml format")


class PipelineParallel(nn.Module):

    def __init__(self, pipeline, loss_fn, vocab_size, ignore_index=-1):
        super().__init__()
        self.pipeline = pipeline
        self.loss_fn = loss_fn
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index

    def __call__(self, x, y=None):
        _, t = x.shape

        loss = None
        logits = self.pipeline(x).local_value()  # Pipe returns an RRef; local_value() fetches the output tensor

        if y is not None:
            loss = self.loss_fn(
                input=logits.view(-1, self.vocab_size),
                target=y.view(-1),
                ignore_index=self.ignore_index,
            )
        return logits, loss

    def make_data_parallel(self, process_group):
        self.pipeline = DDP(self.pipeline, process_group=process_group)

    @classmethod
    def from_sequential_model(
        cls,
        model,
        num_pipeline_stages,
        num_micro_batches,
        activation_ckpt_mode,
        deferred_batch_norm,
    ):
        """
        Approach copied from https://medium.com/pytorch/pytorch-data-parallel-best-practices-on-google-cloud-6c8da2be180d
        """
        local_rank = int(os.environ["LOCAL_RANK"])
        first_stage_rank = (local_rank // num_pipeline_stages) * num_pipeline_stages

        stage_idx = 0
        num_stage_params = 0
        stage_to_layers = defaultdict(list)
        sequential = model.as_sequential()
        num_model_params = sum([p.numel() for p in sequential.parameters()])
        max_params_per_stage = math.ceil(num_model_params / num_pipeline_stages)

        # partition the model into stages
        for layer in sequential:
            num_layer_params = sum([p.numel() for p in layer.parameters()])
            is_stage_full = num_stage_params + num_layer_params > max_params_per_stage
            is_last_stage = stage_idx == num_pipeline_stages - 1
            if is_stage_full and not is_last_stage:
                stage_idx += 1
                num_stage_params = num_layer_params
            else:
                num_stage_params += num_layer_params
            stage_to_layers[stage_idx].append(layer)

        # put each stage onto a GPU
        for i, stage in stage_to_layers.items():
            device = f"cuda:{first_stage_rank+i}"
            for layer in stage:
                layer.to(device=device)

        assert len(stage_to_layers) == num_pipeline_stages
        pipeline = Pipe(
            module=nn.Sequential(*[nn.Sequential(*stage_to_layers[j]) for j in range(num_pipeline_stages)]),
            chunks=num_micro_batches,
            checkpoint=activation_ckpt_mode,
            deferred_batch_norm=deferred_batch_norm,
        )
        return cls(pipeline=pipeline, loss_fn=model.loss_fn, vocab_size=model.vocab_size)


def main(config):
    config.rank = int(os.environ['RANK'])
    config.local_rank = int(os.environ['LOCAL_RANK'])
    config.world_size = int(os.environ['WORLD_SIZE'])
    config.master_addr = os.environ['MASTER_ADDR']
    config.master_port = os.environ['MASTER_PORT']

    init_process_group(
        init_method=f"tcp://{config.master_addr}:{config.master_port}",
        backend=Backend.GLOO, rank=config.rank, world_size=config.world_size
    )
    rpc.init_rpc(
        f"worker:{config.rank}",
        rank=config.rank,
        world_size=config.world_size,
    )
    driver_ranks = [i for i in range(config.world_size) if i % config.model.num_pipeline_stages == 0]
    process_group = new_group(ranks=driver_ranks, backend=Backend.NCCL, timeout=timedelta(days=365))

    # only the driver ranks run the forward/backward pass
    # (hard-coded for 2 pipeline stages here; see note 2 above)
    if config.rank % 2 == 0:
        batch_size = 1
        max_len = config.model.max_len
        # dummy token ids for a quick smoke test
        rand_ints = torch.randint(0, 100, (batch_size, max_len + 1))
        x, y = rand_ints[:, :-1], rand_ints[:, 1:]
        amp_context = Precision.get_amp_context(config.model.precision, config.compute.device)

        # first and last pipeline-stage devices for this 2-GPU example
        input_device = "cuda:0"
        output_device = "cuda:1"
        model = Decoder(config.model)
        model = PipelineParallel.from_sequential_model(
            model, num_pipeline_stages=2, num_micro_batches=2,
            activation_ckpt_mode="never", deferred_batch_norm=False,
        )
        model.make_data_parallel(process_group=process_group)
        with amp_context:
            logits, loss = model(x.to(input_device), y.to(output_device))

    barrier(process_group)
    rpc.shutdown(graceful=True)
    destroy_process_group()


if __name__ == "__main__":
    args = parser.parse_args()
    config = Config.from_yaml(args.config)
    main(config)
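
One last note: the script assumes RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT are already set in the environment, so launch it with torchrun (or any other launcher that exports those variables), with one process per GPU on each node.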