Some background: I'm trying to parallelize a series of module forwards inside each worker of a distributed data parallel training job. Since autograd does not work across process boundaries, my idea is to spawn sub-workers inside each worker, initialize new process groups, and use queues to feed input data and gradients from the outer workers to the sub-workers, which then run the forward and backward internally. Since the set of modules I run in parallel is identical, I assume that the parameters of the same module across all sub-workers with the same rank (the sub-worker rank) should end up with the same gradients after backward. However, in the experiment below, the gradients differ across sub-workers with the same rank in the same process group. Did I do something wrong?
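For concreteness: the model is nn.Linear(1, 1, bias=False), so the forward is out = w * x and each sub-worker's local gradient is d(out)/dw = x. The inputs are x = (rank + 1) * 10 + sub_rank, i.e. 10 + s on parent rank 0 and 20 + s on parent rank 1 for sub-rank s. If I understand DDP's gradient averaging correctly, every sub-worker in the sub-rank-s group should therefore end up with grad = ((10 + s) + (20 + s)) / 2 = 15 + s, rather than the differing values I observe.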
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
from torch.nn.parallel import DistributedDataParallel
import torch.nn as nn
from torch.nn.init import uniform_
from loguru import logger
import os
import torch
import time
world_size = 2
sub_world_size = 4
def init_fn(module):
    if isinstance(module, nn.Linear):
        uniform_(module.weight)
def main_worker(rank):
    logger.info(f"enter main worker: {rank}")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    out_queues = [Queue() for _ in range(sub_world_size)]
    grad_queues = [Queue() for _ in range(sub_world_size)]  # created but not used in this minimal repro
    model = nn.Linear(1, 1, bias=False).apply(init_fn)
    # one input per sub-worker; values differ per (rank, sub_rank) pair
    in_tensors = [torch.tensor([(rank + 1) * 10 + sub_rank], dtype=torch.float32) for sub_rank in range(sub_world_size)]
    logger.info(f"main worker before {rank=} model {model.weight.grad=}")
    ctx = mp.spawn(sub_worker, nprocs=sub_world_size, args=(rank, model, in_tensors, out_queues), join=False)
    ctx.join()
    time.sleep(5)
    logger.info(f"main worker after {rank=} model {model.weight.grad=}")
def sub_worker(rank, parent_rank, model, in_tensors, out_queues):
    # sub-workers with the same sub-rank (across parent workers) rendezvous on the same port,
    # so each sub-rank forms its own process group of size world_size with rank=parent_rank
    dist.init_process_group("nccl", rank=parent_rank, world_size=world_size, init_method=f'tcp://127.0.0.1:{29502 + rank}')
    in_tensor, out_queue = in_tensors[rank], out_queues[rank]
    out_tensor = DistributedDataParallel(model.cuda(parent_rank))(in_tensor.cuda(parent_rank))
    out_tensor_detached = out_tensor.detach()
    # out_queue.put(out_tensor_detached)
    out_tensor.backward()
    logger.info(f"sub worker: {rank=}, {parent_rank=}, {model.weight.grad=}")
if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    ctx = mp.spawn(main_worker, nprocs=world_size, args=(), join=False)
    ctx.join()
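For context, this is roughly how I was planning to wire up the currently unused grad_queues, i.e. feed the downstream gradient from the parent back into each sub-worker's backward instead of the plain out_tensor.backward() above. It is only a sketch under the same setup as the script: the helper names and the stacked-sum loss in the parent are just placeholders for illustration.

import torch

# Sub-worker side: publish the detached output, wait for the upstream
# gradient from the parent, then run backward with it.
def sub_worker_backward(out_tensor, out_queue, grad_queue, parent_rank):
    out_queue.put(out_tensor.detach().cpu())              # hand the result to the parent
    upstream_grad = grad_queue.get()                      # blocking wait for dL/d(out)
    out_tensor.backward(upstream_grad.cuda(parent_rank))  # continue backward inside the sub-worker

# Parent side: consume the sub-worker outputs, compute whatever loss
# uses them, and send the matching gradient back to each sub-worker.
def main_worker_backward(out_queues, grad_queues):
    outs = [q.get().requires_grad_() for q in out_queues]  # leaf copies in the parent's graph
    loss = torch.stack(outs).sum()                          # placeholder downstream loss
    loss.backward()
    for out, grad_queue in zip(outs, grad_queues):
        grad_queue.put(out.grad)                            # dL/d(out) back to the sub-worker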