Hi,
I plan to have a multi-node setup: spawn 4 processes on each node, create a CPU tensor in each process, use a DTensor to "stitch" them together (vertically, along dim 0), and access them uniformly. For 8 processes, each with 10 rows and 5 columns, the DTensor should have a shape of 80x5. Eventually, I want to gather arbitrary rows of the DTensor that are spread across multiple nodes/processes, something like this:
x = dtensor[[0, 20, 30, 45, 75]]
x.to('cuda:0')  # Further processing
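To make the layout concrete, here is a single-process sketch of the global view I am expecting (in the repro below, each rank fills its local 10x5 shard with its own rank id, so the stitched tensor should look like this):

import torch

# Expected stitched view: rank r owns global rows 10*r .. 10*r + 9
expected = torch.cat([torch.full((10, 5), float(r)) for r in range(8)], dim=0)
print(expected.shape)                 # torch.Size([80, 5])
print(expected[[0, 20, 30, 45, 75]])  # rows owned by ranks 0, 2, 3, 4, 7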
For a 2-node, 8-GPU setup, I am trying the code below, but I am running into the following issues:
- The DTensor's shape shows up as 10x5 instead of the expected 80x5.
- When I try x = dtensor[[0, 20, 30, 45, 75]], it fails with an error about a mix of torch.Tensor and DTensor arguments in an aten op (see the snippet right after this list).
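For reference, this is the indexing attempt on the repro's dtensor that fails (I don't have the full traceback handy, but the message complains about an aten op receiving a mix of torch.Tensor and DTensor arguments):

rows = [0, 20, 30, 45, 75]
x = dtensor[rows]   # fails: list indexing apparently creates a plain index tensor,
                    # which then gets mixed with the DTensor inside the aten op
x = x.to('cuda:0')  # never reached; this is where I would move the rows to the GPU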
Could anyone help?
Code:
import os
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, DeviceMesh, Shard
def init_distributed():
    # Get SLURM environment variables
    rank = int(os.environ['SLURM_PROCID'])         # Global rank across all tasks
    local_rank = int(os.environ['SLURM_LOCALID'])  # Rank within the node
    world_size = int(os.environ['SLURM_NTASKS'])   # Total number of tasks

    # Set required environment variables for PyTorch
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = os.environ.get('MASTER_ADDR', 'localhost')
    os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')

    # Initialize the distributed environment
    dist.init_process_group(backend='gloo', init_method='env://')

    # Create a device mesh (1D mesh across all ranks)
    device = torch.device('cpu')  # Using CPU since we're using Gloo
    mesh = DeviceMesh(device_type="cpu", mesh=[list(range(world_size))])

    # Define the global tensor shape (80x5 for 8 processes, 10x5 per process)
    global_shape = (world_size * 10, 5)

    # Create a local tensor for each rank (10x5), filled with rank value
    local_tensor = torch.full((10, 5), float(rank), dtype=torch.float32, device=device)
    print(f"Rank {rank}: Created local tensor shape {local_tensor.shape} filled with {rank}")

    # Create a single DTensor by sharding the global tensor across the 0th dimension
    dist.barrier()
    dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)], run_check=False)

    # Verify the DTensor shape (should be 80x5 globally)
    print(f"Rank {rank}: DTensor global shape {dtensor.shape}")
    print(f"Rank {rank}: DTensor local tensor shape {dtensor.to_local().shape}")

    # Print the full tensor on rank 0
    if rank == 0:
        full_tensor = dtensor.full_tensor()
        print(f"Rank {rank}: Full tensor shape {full_tensor.shape}")
        print(f"Rank {rank}: Full tensor:\n{full_tensor}")

    # Clean up
    dist.destroy_process_group()
if __name__ == "__main__":
    init_distributed()
Output:
4: Rank 4: Created local tensor shape torch.Size([10, 5]) filled with 4
4: Rank 4: DTensor global shape torch.Size([10, 5])
4: Rank 4: DTensor local tensor shape torch.Size([10, 5])
7: Rank 7: Created local tensor shape torch.Size([10, 5]) filled with 7
7: Rank 7: DTensor global shape torch.Size([10, 5])
7: Rank 7: DTensor local tensor shape torch.Size([10, 5])
6: Rank 6: Created local tensor shape torch.Size([10, 5]) filled with 6
6: Rank 6: DTensor global shape torch.Size([10, 5])
6: Rank 6: DTensor local tensor shape torch.Size([10, 5])
5: Rank 5: Created local tensor shape torch.Size([10, 5]) filled with 5
5: Rank 5: DTensor global shape torch.Size([10, 5])
5: Rank 5: DTensor local tensor shape torch.Size([10, 5])
1: Rank 1: Created local tensor shape torch.Size([10, 5]) filled with 1
1: Rank 1: DTensor global shape torch.Size([10, 5])
1: Rank 1: DTensor local tensor shape torch.Size([10, 5])
2: Rank 2: Created local tensor shape torch.Size([10, 5]) filled with 2
2: Rank 2: DTensor global shape torch.Size([10, 5])
2: Rank 2: DTensor local tensor shape torch.Size([10, 5])
3: Rank 3: Created local tensor shape torch.Size([10, 5]) filled with 3
3: Rank 3: DTensor global shape torch.Size([10, 5])
3: Rank 3: DTensor local tensor shape torch.Size([10, 5])
0: Rank 0: Created local tensor shape torch.Size([10, 5]) filled with 0
0: Rank 0: DTensor global shape torch.Size([10, 5])
0: Rank 0: DTensor local tensor shape torch.Size([10, 5])
0: Rank 0: Full tensor shape torch.Size([10, 5])
0: Rank 0: Full tensor:
0: tensor([[0., 0., 0., 0., 0.],
0: [0., 0., 0., 0., 0.],
0: [0., 0., 0., 0., 0.],
0: [0., 0., 0., 0., 0.],
0: [0., 0., 0., 0., 0.],
0: [0., 0., 0., 0., 0.],
0: [0., 0., 0., 0., 0.],
0: [0., 0., 0., 0., 0.],
0: [0., 0., 0., 0., 0.],
0: [0., 0., 0., 0., 0.]])