I have a list of queries for which I'm trying to compute embeddings using DDP. Currently I do this by collecting the (example_id, embedding) pairs on each device and then writing them to separate files named `{gpu_id}_output.txt`.
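For reference, this is roughly what I do on each rank today (the function name and variable names are just placeholders for the per-rank results I have already computed):

```python
import torch.distributed as dist

def write_per_rank(example_ids, embeddings):
    # example_ids: 1-D tensor of ids computed on this rank
    # embeddings: 2-D tensor of embeddings computed on this rank
    rank = dist.get_rank()
    with open(f"{rank}_output.txt", "w") as f:
        for eid, emb in zip(example_ids.tolist(), embeddings.cpu().tolist()):
            f.write(f"{eid}\t{emb}\n")
```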
Is there a better way to gather the (example_id, embedding) pairs into a single file with DDP? I can think of the following approaches:
- Do `torch.distributed.all_gather(example_id)` and `torch.distributed.all_gather(embedding)` and then write to `output.txt` only on GPU 0. However, my concern is that `all_gather` is not guaranteed to gather the example_ids and embeddings in the same order. For instance, the gathered `example_id` tensors could come from GPUs [3, 0, 1, 2] while the gathered `embedding` tensors come from GPUs [0, 1, 2, 3]. Is there a way to reconcile which example_id corresponds to which embedding when doing `all_gather`? (A rough sketch of what I mean is the first code block after this list.)
- The other approach I can think of is creating a tuple of `(example_id.numpy(), embedding.numpy())` and then using `torch.distributed.all_gather_object` (second code block after this list).
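To make the first option concrete, here is a minimal sketch of what I have in mind, assuming `example_id` and `embedding` are same-shaped tensors on every rank (the function name is just a placeholder):

```python
import torch
import torch.distributed as dist

def gather_with_all_gather(example_id, embedding):
    world_size = dist.get_world_size()
    id_list = [torch.zeros_like(example_id) for _ in range(world_size)]
    emb_list = [torch.zeros_like(embedding) for _ in range(world_size)]
    dist.all_gather(id_list, example_id)
    dist.all_gather(emb_list, embedding)
    # This is my question: are id_list[i] and emb_list[i] guaranteed
    # to come from the same rank i across the two all_gather calls?
    if dist.get_rank() == 0:
        with open("output.txt", "w") as f:
            for ids, embs in zip(id_list, emb_list):
                for eid, emb in zip(ids.tolist(), embs.cpu().tolist()):
                    f.write(f"{eid}\t{emb}\n")
```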
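And a sketch of the second option with `all_gather_object`; here each (ids, embeddings) pair travels as one object, so the ids stay paired with their embeddings regardless of gather order (again, names are placeholders):

```python
import torch.distributed as dist

def gather_with_all_gather_object(example_id, embedding):
    pair = (example_id.cpu().numpy(), embedding.cpu().numpy())
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, pair)
    if dist.get_rank() == 0:
        with open("output.txt", "w") as f:
            for ids, embs in gathered:
                for eid, emb in zip(ids.tolist(), embs.tolist()):
                    f.write(f"{eid}\t{emb}\n")
```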
What is the recommended way to handle this?
Thanks!