Why doesn't DistributedDataParallel sync parameter gradients across ranks in a process group?

The background is that I’m trying to parallelize a series of module forward passes within each worker during distributed data parallel training. Since autograd does not support crossing process boundaries, I came up with the idea of spawning sub-workers in each worker, initializing new process groups, and using queues to feed input data and gradients from the outer workers to the sub-workers, which run the forward and backward passes. Since the modules I’m running in parallel are identical, I assumed that the same module parameters across all sub-workers with the same rank (the rank within the sub-worker group) would have the same gradients after backward. However, in my experiment below, I observed that the gradients vary across sub-workers with the same rank in the same process group. Did I do something wrong?

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
import torch.nn as nn
from torch.nn.init import uniform_
from loguru import logger
import os
import torch
import time

# Number of main-worker processes spawned under __main__; also the world size
# of the top-level group and of every sub-worker process group.
world_size: int = 2
# Number of sub-worker processes each main worker spawns (one input each).
sub_world_size: int = 4

def init_fn(module):
    """Initialise the weight of ``nn.Linear`` layers in place.

    Intended for ``nn.Module.apply``, which calls it on every submodule;
    modules of any other type are left untouched.

    NOTE(review): the body of the ``if`` was missing (a SyntaxError).
    ``uniform_`` is imported for this purpose, so a uniform init is presumably
    what was intended; this uses ``uniform_``'s default range [0, 1) — confirm
    the intended bounds.
    """
    if isinstance(module, nn.Linear):
        # nn.init.uniform_ already runs under no_grad, so this is safe on
        # parameters that require grad.
        uniform_(module.weight)

def main_worker(rank):
    """Run one outer worker: join the top-level group and fan out sub-workers.

    Spawns ``sub_world_size`` sub-worker processes. Each receives a pickled
    copy of ``model``, the list of per-sub-worker inputs and the result queues.

    Args:
        rank: Rank of this worker in the top-level process group; it is passed
            to the sub-workers as ``parent_rank``.

    NOTE: ``mp.spawn`` pickles ``model`` into every child, so gradients
    computed inside a sub-worker never appear on ``model.weight.grad`` here.
    Results have to be sent back explicitly (e.g. through ``out_queues``).
    """
    logger.info(f"enter main worker: {rank}")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # Queues handed to spawn-started children should come from the "spawn"
    # context (the bare name ``Queue`` was never imported).
    spawn_ctx = mp.get_context("spawn")
    out_queues = [spawn_ctx.Queue() for _ in range(sub_world_size)]
    model = nn.Linear(1, 1, bias=False).apply(init_fn)
    # Distinct scalar input per (rank, sub_rank) pair, e.g. 10..13 for rank 0.
    in_tensors = [
        torch.tensor([(rank + 1) * 10 + sub_rank], dtype=torch.float32)
        for sub_rank in range(sub_world_size)
    ]
    logger.info(f"main worker before {rank=} model {model.weight.grad=}")
    try:
        # join=True waits for all sub-workers and re-raises the first failure
        # instead of silently discarding the ProcessContext. If sub-workers
        # start putting data on out_queues, drain them before/while joining.
        mp.spawn(
            sub_worker,
            nprocs=sub_world_size,
            args=(rank, model, in_tensors, out_queues),
            join=True,
        )
        logger.info(f"main worker after {rank=} model {model.weight.grad=}")
    finally:
        dist.destroy_process_group()

def sub_worker(rank, parent_rank, model, in_tensors, out_queues):
    """Run one forward + backward pass of ``model`` under DDP in a sub-worker.

    Sub-workers that share the same sub-rank ``rank`` (one per main worker)
    form their own process group, rendezvousing over TCP on port
    ``29502 + rank``; inside that group this process has rank ``parent_rank``.

    Args:
        rank: Index of this sub-worker within its parent (supplied by
            ``mp.spawn``); selects the input and the rendezvous port.
        parent_rank: Rank of the spawning main worker; used as the rank in the
            sub-group and as the CUDA device index.
        model: Module to wrap in ``DistributedDataParallel``. This is a pickled
            copy, so the parent process never sees its gradients.
        in_tensors: Per-sub-worker inputs; ``in_tensors[rank]`` is used.
        out_queues: Per-sub-worker result queues; currently unused.
    """
    device = parent_rank
    # NCCL expects the current CUDA device to match the rank's device.
    torch.cuda.set_device(device)
    # A tcp:// init_method must be "host:port"; fall back to loopback when
    # MASTER_ADDR is unset or empty.
    master_addr = os.environ.get("MASTER_ADDR") or "127.0.0.1"
    dist.init_process_group(
        "nccl",
        rank=parent_rank,
        world_size=world_size,
        init_method=f"tcp://{master_addr}:{29502 + rank}",
    )
    try:
        ddp_model = DistributedDataParallel(model.cuda(device), device_ids=[device])
        out_tensor = ddp_model(in_tensors[rank].cuda(device))
        # DDP all-reduces (averages) gradients across the group only during
        # backward(); without this call model.weight.grad stays None and
        # nothing is synchronised.
        out_tensor.sum().backward()
        logger.info(f"sub worker: {rank=}, {parent_rank=}, {model.weight.grad=}")
    finally:
        dist.destroy_process_group()

if __name__ == "__main__":
    # Rendezvous for the top-level env:// process group. An empty MASTER_ADDR
    # cannot be resolved, so default to loopback unless the launcher set one.
    if not os.environ.get("MASTER_ADDR"):
        os.environ["MASTER_ADDR"] = "127.0.0.1"
    # Kept fixed: the sub-worker groups use ports 29502 + sub_rank.
    os.environ["MASTER_PORT"] = "29501"
    # join=True waits for every main worker and re-raises the first failure.
    mp.spawn(main_worker, nprocs=world_size, args=(), join=True)