Shard Tensor Across Specific Ranks

Hello,

We are two researchers trying to use tensor sharding for a custom pipelined model.
In particular, we would like to use row-wise tensor sharding, but only across ranks chosen by the user. The problem is that initializing a device mesh requires every rank in the main (default) process group to take part, and we cannot specify which ranks to exclude. Ideally, we would like to initialize a device mesh from a user-defined process group.

Is there a way to do this, or can one pass a process group to the mesh?

For a more concrete example, please see the code below:

import os 
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh

def prepare_distributed_environment(rank=None, master_addr=None, master_port=None, world_size=None):
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    dist.init_process_group(backend='gloo', rank=rank, world_size=world_size)

def main(rank=None, master_addr=None, master_port=None, world_size=None):
    # Suppose we have 4 ranks in total, i.e. ranks 0, 1, 2, and 3
    prepare_distributed_environment(rank, master_addr, master_port, world_size)
    # Shard first tensor across ranks [0,1], multiply it with a weight matrix, and send the output to ranks [2,3]
    rowwise_placement=[Shard(0)]
    if rank in [0,1]:
        device_mesh = init_device_mesh("cuda", (2,)) 
        
        # Define some tensors for testing
        vector = torch.tensor([0,1,2,3,4,5,6,7,8,9], dtype=torch.float32).view(10, 1)
        matrix = torch.eye(10, 10) 

        # NOTE: Currently the code gets stuck HERE because these functions require every rank to call them
        sharded_matrix = distribute_tensor(matrix, device_mesh=device_mesh, placements=rowwise_placement)
        sharded_vector = distribute_tensor(vector, device_mesh=device_mesh, placements=rowwise_placement)
        
        output = sharded_matrix @ sharded_vector
        dist.send(tensor=output, dst=[2,3][[0,1].index(rank)])
        
    if rank in [2,3]:
        device_mesh = init_device_mesh("cuda", (2,)) # intended to shard across ranks [2,3], but this call always builds the mesh [0, 1]
        vector = torch.empty(10, 1)
        sharded_input = distribute_tensor(vector, device_mesh=device_mesh, placements=rowwise_placement)
        # We update the local tensor with the received tensor
        dist.recv(tensor=sharded_input._local_tensor, src=[0,1][[2,3].index(rank)])
        
        # At this point we want to perform computations with the received tensor and another sharded matrix by only using ranks [2,3]

if __name__ == '__main__':  
    world_size = 4
    mp.spawn(main, args=('localhost', '12345', world_size), nprocs=world_size, join=True)
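For reference, here is a small pure-Python sketch (not PyTorch code) of which rows each mesh rank would hold under Shard(0). It assumes, to our understanding of DTensor, that Shard splits a dimension the way torch.chunk does: chunks of ceil(rows / ranks) rows, with the last rank possibly getting fewer rows or none. The helper name shard_rows is hypothetical.

```python
from __future__ import annotations


def shard_rows(n_rows: int, n_ranks: int) -> list[tuple[int, int]]:
    """Return the [start, stop) row range each mesh rank holds under Shard(0).

    Assumes torch.chunk-style splitting: every chunk has ceil(n_rows / n_ranks)
    rows except possibly the last; ranks left without a chunk get an empty range.
    """
    chunk = -(-n_rows // n_ranks)  # ceiling division
    ranges = []
    for i in range(n_ranks):
        start = min(i * chunk, n_rows)
        stop = min(start + chunk, n_rows)
        ranges.append((start, stop))
    return ranges


if __name__ == "__main__":
    # Ten rows on a two-rank mesh, as in the example above.
    print(shard_rows(10, 2))  # [(0, 5), (5, 10)]
    # Uneven split: chunks of 3 rows, so the last rank holds a single row.
    print(shard_rows(10, 4))
```

For the 10-row tensors above on a two-rank mesh, this gives rows 0 to 4 on the first rank and rows 5 to 9 on the second.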

@cruzas2 Initializing a device mesh requires all ranks in the default process group to go through it because the underlying new_group call requires that all processes in the main group (i.e. all processes that are part of the distributed job) enter the function, even if they are not going to be members of the group (reference: Distributed communication package - torch.distributed — PyTorch 2.4 documentation).
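A minimal sketch of the calling pattern this implies, assuming the default process group is already initialized: every rank creates every subgroup, in the same order, and keeps only the handle for the group it belongs to. build_my_group is a hypothetical helper; new_group is passed in so the logic can be checked without a distributed job, and in a real job you would pass torch.distributed.new_group.

```python
from __future__ import annotations

from typing import Any, Callable, Sequence


def build_my_group(
    rank: int,
    rank_groups: Sequence[Sequence[int]],
    new_group: Callable[..., Any],
) -> Any:
    """Create every subgroup on every rank, but return only this rank's handle.

    new_group is expected to behave like torch.distributed.new_group: a
    collective that all ranks of the default group must call, with the same
    arguments and in the same order, even for groups they do not belong to.
    """
    mine = None
    for ranks in rank_groups:
        handle = new_group(ranks=list(ranks))  # called on *every* rank
        if rank in ranks:
            mine = handle
    if mine is None:
        raise ValueError(f"rank {rank} is not in any of {rank_groups}")
    return mine
```

In the four-rank example this would be called as build_my_group(dist.get_rank(), [[0, 1], [2, 3]], dist.new_group) on all four ranks. PyTorch also provides torch.distributed.new_subgroups_by_enumeration, which, as far as we know, follows the same all-ranks-participate rule.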

For your use case, it should be helpful to create a 2D DeviceMesh:

device_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dim0", "dim1"))

You can then slice out device_mesh["dim1"]. This gives you the submesh [0, 1] on ranks 0 and 1, and the submesh [2, 3] on ranks 2 and 3. See the reference here: pytorch/torch/distributed/device_mesh.py at main · pytorch/pytorch · GitHub
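To make the slicing concrete, here is a pure-Python model of which ranks end up in each 1-D submesh. It assumes the row-major layout that init_device_mesh builds, i.e. the ranks arange(prod(mesh_shape)) reshaped to mesh_shape; submesh_ranks is a hypothetical helper, not part of the DeviceMesh API.

```python
from __future__ import annotations

import math


def submesh_ranks(mesh_shape: tuple[int, ...], dim: int, rank: int) -> list[int]:
    """Ranks in the 1-D slice along `dim` that contains `rank`.

    Assumes a row-major mesh, i.e. arange(prod(mesh_shape)).view(mesh_shape).
    """
    size = math.prod(mesh_shape)
    if not 0 <= rank < size:
        raise ValueError(f"rank {rank} is not in a mesh of {size} ranks")
    # Row-major strides, e.g. [2, 1] for a (2, 2) mesh.
    strides = [math.prod(mesh_shape[i + 1:]) for i in range(len(mesh_shape))]
    # This rank's coordinates in the mesh.
    coords = [(rank // s) % n for s, n in zip(strides, mesh_shape)]
    # Vary only the coordinate along `dim`, keeping the others fixed.
    base = rank - coords[dim] * strides[dim]
    return [base + i * strides[dim] for i in range(mesh_shape[dim])]
```

With mesh_shape (2, 2), dim 1 gives [0, 1] and [2, 3] as described above, while dim 0 gives [0, 2] and [1, 3], which happen to be the sender/receiver pairs of a two-stage pipeline. With mesh_shape (2,), as in the original question, the mesh contains only ranks 0 and 1, so ranks 2 and 3 are not part of it at all.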

Doc pointer: Getting Started with DeviceMesh — PyTorch Tutorials 2.4.0+cu121 documentation


Hi @irisz,

Thanks a lot! Your answer pointed us in the right direction. For reference, we are using PyTorch 2.2.2+cu118. We found that part of our issue was that the first call to distribute_tensor() also requires all processes in the main group to go through it, so we now create a dummy sharded tensor on every rank beforehand. Below is our updated code.

def main(rank=None, master_addr=None, master_port=None, world_size=None):
    # Suppose we have 4 ranks in total, i.e. ranks 0, 1, 2, and 3
    prepare_distributed_environment(rank, master_addr, master_port, world_size)
    # Shard first tensor across ranks [0,1], multiply it with a weight matrix, and send the output to ranks [2,3]
    rowwise_placement=[Shard(0)]
    our_ranks = [0,1]
    next_ranks = [2,3]

    device_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dim0", "dim1")) 
    # Dummy sharded tensor: the first distribute_tensor() call must be made by every rank
    matrix = torch.eye(10, 10)
    sharded_matrix = distribute_tensor(matrix, device_mesh=device_mesh["dim1"], placements=rowwise_placement)
    if rank in [0,1]:
        # Define some tensors for testing
        vector = torch.tensor([0,1,2,3,4,5,6,7,8,9], dtype=torch.float32).view(10, 1)
        matrix = torch.eye(10, 10) 

        dm = device_mesh["dim1"]
        sharded_matrix = distribute_tensor(matrix, device_mesh=dm, placements=rowwise_placement)
        sharded_vector = distribute_tensor(vector, device_mesh=dm, placements=rowwise_placement)
        
        output = sharded_matrix @ sharded_vector
        output = output._local_tensor.cpu()
        
        v = [0,1,2,3,4] if rank == 0 else [5,6,7,8,9]
        solution = torch.tensor(v, dtype=torch.float32).view(5, 1)
        print(f"Rank {rank} computed tensor - error {torch.norm(output - solution)}")
        dist.send(tensor=output, dst=next_ranks[our_ranks.index(rank)])
    if rank in [2,3]:
        vector = torch.ones(10, 1)
        sharded_input = distribute_tensor(vector, device_mesh=device_mesh["dim1"], placements=rowwise_placement)
        sharded_matrix = distribute_tensor(2*torch.eye(10, 10), device_mesh=device_mesh["dim1"], placements=rowwise_placement)
        # We update the local tensor with the received tensor
        temp = sharded_input._local_tensor.cpu()
        dist.recv(tensor=temp, src=our_ranks[next_ranks.index(rank)])
        sharded_input._local_tensor = temp.to('cuda')
        
        output = sharded_matrix @ sharded_input
        
        v = [0,2,4,6,8] if rank == 2 else [10,12,14,16,18]
        solution = torch.tensor(v, dtype=torch.float32).view(5, 1)
        print(f"Rank {rank} received tensor - error {torch.norm(output._local_tensor.cpu() - solution)}")
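As a sanity check on the solution tensors in the code above, the expected local results can be worked out without GPUs. This pure-Python sketch assumes an even Shard(0) split and pairs the i-th rank of the first stage with the i-th rank of the second, as our_ranks and next_ranks do; peer, local_rows, expected_local_outputs, STAGE1 and STAGE2 are hypothetical names.

```python
from __future__ import annotations

STAGE1 = [0, 1]  # ranks that compute eye(10) @ x
STAGE2 = [2, 3]  # ranks that compute (2 * eye(10)) @ received input


def peer(rank: int, src: list[int], dst: list[int]) -> int:
    """Send/recv partner: the i-th rank of src is paired with the i-th rank of dst."""
    if rank in src:
        return dst[src.index(rank)]
    return src[dst.index(rank)]


def local_rows(full: list[float], pos: int, n_ranks: int) -> list[float]:
    """Rows held by the rank at position pos under an even Shard(0) split."""
    chunk = len(full) // n_ranks  # assumes len(full) is divisible by n_ranks
    return full[pos * chunk:(pos + 1) * chunk]


def expected_local_outputs(n: int = 10) -> dict[int, list[float]]:
    """Expected local output shard on every rank for x = [0, 1, ..., n - 1]."""
    x = [float(i) for i in range(n)]
    stage1 = x  # eye(n) @ x leaves x unchanged
    # Stage 2 reassembles the received shards into x and multiplies by 2 * eye(n).
    stage2 = [2.0 * v for v in stage1]
    out: dict[int, list[float]] = {}
    for ranks, full in ((STAGE1, stage1), (STAGE2, stage2)):
        for pos, r in enumerate(ranks):
            out[r] = local_rows(full, pos, len(ranks))
    return out


if __name__ == "__main__":
    for r, rows in expected_local_outputs().items():
        print(f"rank {r}: peer {peer(r, STAGE1, STAGE2)}, expected {rows}")
```

The resulting lists match the v values used to build solution on each rank.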