DTensor: how to get the "global tensor" on each rank after sharding


Currently, I am sharding `tensor_a` across 4 processes on CPU. After sharding the tensor, is there any way to get back the global tensor on each rank?

    tensor_a = torch.rand(1, 1, 12, 10)
    tensor_a = distribute_tensor(tensor_a, mesh, [Shard(dim=2)])


You should be able to do this. Since `tensor_a` is already a `DTensor` at that point, redistribute it to a `Replicate()` placement rather than calling `distribute_tensor` on it again:

    replica_tensor = tensor_a.redistribute(mesh, [Replicate()])

`replica_tensor.to_local()` then holds the full global tensor on every rank. `DTensor.full_tensor()` is a convenience method that does the same redistribute-to-`Replicate()` and `to_local()` in one call.
