Here’s an example showing the model parallelism + DDP setup that worked for me. TLDR: I initialize RPC, create a process group containing the “driver” ranks (one per model replica) that run the forward and backward passes, and then set up my model with DDP(Pipe(model), process_group=process_group).
Notes:
- The example code below won’t run out of the box because it’s missing some of the source (e.g., my Config class), but hopefully it gives you an idea of how to get it set up!
- My example is written assuming 2-stage model parallelism on a 2-GPU machine. To adapt it for an N-GPU machine, you’d need to change the logic for setting input_device and output_device and change the condition for running the forward pass to config.rank % num_pipeline_stages == 0 (see the short sketch right after these notes).
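For reference, here’s a rough sketch of that generalization. It assumes (as the layer-placement logic in the full script does) that each replica’s pipeline stages sit on consecutive local GPUs; config.rank, config.local_rank, and config.model.num_pipeline_stages are the same fields the script below uses:

# Hypothetical N-GPU generalization: one rank per replica drives the pipeline,
# inputs go to the replica's first GPU, targets/logits to its last GPU.
is_driver = config.rank % config.model.num_pipeline_stages == 0
first_stage_gpu = (config.local_rank // config.model.num_pipeline_stages) * config.model.num_pipeline_stages
input_device = f"cuda:{first_stage_gpu}"
output_device = f"cuda:{first_stage_gpu + config.model.num_pipeline_stages - 1}"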
import argparse
import math
import os
from collections import defaultdict
from datetime import timedelta

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import rpc, init_process_group, new_group, Backend, destroy_process_group, barrier
from torch.distributed.pipeline.sync import Pipe

from model import Decoder
from utils import Config, Precision

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, help="Path to training config file in yaml format")

class PipelineParallel(nn.Module):
    def __init__(self, pipeline, loss_fn, vocab_size, ignore_index=-1):
        super().__init__()
        self.pipeline = pipeline
        self.loss_fn = loss_fn
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index

    def forward(self, x, y=None):
        loss = None
        # Pipe returns an RRef; the actual output tensor lives on the last stage's GPU.
        logits = self.pipeline(x).local_value()
        if y is not None:
            loss = self.loss_fn(
                input=logits.view(-1, self.vocab_size),
                target=y.view(-1),
                ignore_index=self.ignore_index,
            )
        return logits, loss

    def make_data_parallel(self, process_group):
        # Wrap the pipeline in DDP so gradients are all-reduced across model replicas
        # (only the driver ranks are members of process_group).
        self.pipeline = DDP(self.pipeline, process_group=process_group)

    @classmethod
    def from_sequential_model(
        cls,
        model,
        num_pipeline_stages,
        num_micro_batches,
        activation_ckpt_mode,
        deferred_batch_norm,
    ):
        """
        Approach copied from https://medium.com/pytorch/pytorch-data-parallel-best-practices-on-google-cloud-6c8da2be180d
        """
        local_rank = int(os.environ["LOCAL_RANK"])
        # First GPU owned by this replica's pipeline (stages occupy consecutive local GPUs).
        first_stage_rank = (local_rank // num_pipeline_stages) * num_pipeline_stages
        stage_idx = 0
        num_stage_params = 0
        stage_to_layers = defaultdict(list)
        sequential = model.as_sequential()
        num_model_params = sum(p.numel() for p in sequential.parameters())
        max_params_per_stage = math.ceil(num_model_params / num_pipeline_stages)
        # Partition the model into stages, balancing by parameter count.
        for layer in sequential:
            num_layer_params = sum(p.numel() for p in layer.parameters())
            is_stage_full = num_stage_params + num_layer_params > max_params_per_stage
            is_last_stage = stage_idx == num_pipeline_stages - 1
            if is_stage_full and not is_last_stage:
                stage_idx += 1
                num_stage_params = num_layer_params
            else:
                num_stage_params += num_layer_params
            stage_to_layers[stage_idx].append(layer)
        # Put each stage onto its own GPU.
        for i, stage in stage_to_layers.items():
            device = f"cuda:{first_stage_rank + i}"
            for layer in stage:
                layer.to(device=device)
        assert len(stage_to_layers) == num_pipeline_stages
        pipeline = Pipe(
            module=nn.Sequential(*[nn.Sequential(*stage_to_layers[j]) for j in range(num_pipeline_stages)]),
            chunks=num_micro_batches,
            checkpoint=activation_ckpt_mode,
            deferred_batch_norm=deferred_batch_norm,
        )
        return cls(pipeline=pipeline, loss_fn=model.loss_fn, vocab_size=model.vocab_size)


def main(config):
    config.rank = int(os.environ["RANK"])
    config.local_rank = int(os.environ["LOCAL_RANK"])
    config.world_size = int(os.environ["WORLD_SIZE"])
    config.master_addr = os.environ["MASTER_ADDR"]
    config.master_port = os.environ["MASTER_PORT"]
    # Global gloo group over all ranks; used for coordination, not for gradient all-reduce.
    init_process_group(
        init_method="tcp://" + str(config.master_addr) + ":" + str(config.master_port),
        backend=Backend.GLOO,
        rank=config.rank,
        world_size=config.world_size,
    )
    # Pipe uses RPC to move micro-batches between stages, so every rank must join.
    rpc.init_rpc(
        "worker:" + str(config.rank),
        rank=config.rank,
        world_size=config.world_size,
    )
    # One "driver" rank per model replica; these form the NCCL group that DDP all-reduces over.
    driver_ranks = [i for i in range(config.world_size) if i % config.model.num_pipeline_stages == 0]
    process_group = new_group(ranks=driver_ranks, backend=Backend.NCCL, timeout=timedelta(days=365))
    if config.rank % 2 == 0:  # hard-coded for the 2-stage / 2-GPU case (see the notes above)
        batch_size = 1
        max_len = config.model.max_len
        rand_ints = torch.randint(0, 100, (batch_size, max_len + 1))
        x, y = rand_ints[:, :-1], rand_ints[:, 1:]
        amp_context = Precision.get_amp_context(config.model.precision, config.compute.device)
        # Inputs feed the first stage; targets and logits live on the last stage's GPU.
        input_device = "cuda:0"
        output_device = "cuda:1"
        model = Decoder(config.model)
        model = PipelineParallel.from_sequential_model(model, 2, 2, "never", False)  # 2 stages, 2 micro-batches
        model.make_data_parallel(process_group=process_group)
        with amp_context:
            logits, loss = model(x.to(input_device), y.to(output_device))
    barrier(process_group)
    rpc.shutdown(graceful=True)
    destroy_process_group()


if __name__ == "__main__":
    args = parser.parse_args()
    config = Config.from_yaml(args.config)
    main(config)
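In case it helps, this is roughly how I launch it on a single 2-GPU machine (the script and config filenames here are placeholders; torchrun sets the RANK / LOCAL_RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT environment variables the script reads):

torchrun --nproc_per_node=2 train_pipe_ddp.py --config config.yaml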