If I create two matrices A and B as DTensors on the device mesh [[0,1,2],[3,4,5],[6,7,8]], and I want to gather data onto the current process from an arbitrary subset of the ranks in the device mesh's process group (for example, only from ranks [1,3]), should I use dist.new_group or some other method? For example, I want to gather dA.local_tensor() from ranks [1,4] and dB.local_tensor() from ranks [3,4]. I need some help… I will be doing this in a loop, changing the sub-group of ranks to gather from on each iteration, and those sub-groups are randomized QAQ
@spawn_threads_and_init_comms
def collect_dtensor(world_size):
    """Shard a 4x4 tensor over a 2x2 CPU mesh and all-gather along mesh dim 1.

    The tensor holding 0..15 is distributed with ``Shard(0)`` on mesh
    dimension 0 ('dp') and ``Shard(1)`` on mesh dimension 1 ('tp'). Each rank
    then all-gathers the local shards of its mesh-dimension-1 group along
    tensor dimension 1 and prints the result together with its rank.

    Args:
        world_size: Value supplied by the caller; it is not read inside the
            function body.
    """
    # 2x2 mesh over ranks 0..3; rows of the rank grid are the 'tp' groups.
    device_mesh = DeviceMesh("cpu", [[0, 1], [2, 3]], mesh_dim_names=['dp', 'tp'])
    placements = [Shard(0), Shard(1)]

    full_tensor = torch.arange(0, 16).reshape(4, 4)
    sharded = distribute_tensor(full_tensor, device_mesh, placements)

    # Gather only within the mesh-dimension-1 group, concatenating the local
    # shards along tensor dim 1.
    local_block = sharded.to_local()
    row_data = funcol.all_gather_tensor(
        local_block,
        gather_dim=1,
        group=(device_mesh, 1),
    )

    # Synchronise every rank before any of them prints.
    dist.barrier()
    rank = dist.get_rank()
    print(f"rank:{rank},{row_data}")
# Value handed to the demo below; 4 matches the number of ranks listed in the
# 2x2 mesh that collect_dtensor builds.
WORLD_SIZE = 4

# Run the demo at module level.
collect_dtensor(WORLD_SIZE)