Currently I have this code:
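```python
import torch
import torch.distributed as dist

rank = dist.get_rank()  # assumes init_process_group() was already called
output = [None]  # receives the one object scattered to this rank
if rank == 0:
    x = torch.arange(3.)
    scatter_list = list(torch.tensor_split(x, 3))  # one chunk per rank
else:
    scatter_list = []  # ignored on non-src ranks
dist.scatter_object_list(output, scatter_list, src=0)
```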
Now I have a list with 6 elements, and I want to split it so that each rank gets 2 elements. Can I use scatter to do this?
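A minimal sketch of one way to do it: `scatter_object_list` sends exactly one object to each rank, so on the source rank you can build a scatter list of length `world_size` whose i-th entry is the sublist rank i should receive. With 3 ranks and 6 elements, each rank then gets a 2-element list. The helper names `split_into_chunks` and `scatter_chunks` are hypothetical, and the sketch assumes the process group is already initialized (for example under `torchrun` with the `gloo` backend) and that the list length divides evenly by the world size.

```python
def split_into_chunks(items, world_size):
    """Hypothetical helper: split `items` into `world_size` equal, contiguous sublists.

    Rank i receives items[i*k:(i+1)*k], where k = len(items) // world_size.
    """
    if len(items) % world_size != 0:
        raise ValueError("len(items) must be divisible by world_size")
    k = len(items) // world_size
    return [items[i * k:(i + 1) * k] for i in range(world_size)]


def scatter_chunks(items, src=0):
    """Hypothetical helper: call on every rank after dist.init_process_group().

    `items` is only read on the `src` rank; other ranks may pass None.
    Returns the sublist assigned to the calling rank.
    """
    # Imported here so split_into_chunks can be used without torch installed.
    import torch.distributed as dist

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Only the source rank builds the list; it must have world_size entries.
    scatter_list = split_into_chunks(items, world_size) if rank == src else None
    output = [None]  # scatter_object_list fills output[0] on each rank
    dist.scatter_object_list(output, scatter_list, src=src)
    return output[0]
```

For example, launched with `torchrun --nproc_per_node=3 script.py` and after `dist.init_process_group("gloo")`, calling `scatter_chunks(list(range(6)) if dist.get_rank() == 0 else None)` should give rank 0 `[0, 1]`, rank 1 `[2, 3]` and rank 2 `[4, 5]`. If the elements are plain numbers, `dist.scatter` with a list of equally sized tensors (here each of shape `(2,)`) is a tensor-based alternative that avoids pickling the objects.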
Thank you for your help!