DTensor: how to get the "global tensor" on each rank after sharding


Currently, I am sharding `tensor_a` across 4 processes on CPU. After sharding the tensor, is there any way to get back the global tensor on each rank?

    tensor_a = torch.rand(1, 1, 12, 10)
    tensor_a = distribute_tensor(tensor_a, mesh, [Shard(dim=2)])


You should be able to do this. Since `tensor_a` is already a `DTensor` at that point, redistribute it to a `Replicate()` placement rather than calling `distribute_tensor` on it again:

    replica_tensor = tensor_a.redistribute(mesh, [Replicate()])

`replica_tensor.to_local()` then holds the full global tensor on every rank. `DTensor.full_tensor()` is a convenience method that does the same redistribute-to-`Replicate()` and `to_local()` in one call.
