I am running my model on multiple GPUs, and I have a tensor present on each GPU that I want to access. I want to get hold of these tensors from all the GPUs, perform some operation on them synchronously, and then broadcast the result to all the GPUs for use in the next step.
For example, say the tensor is T, present on each of cuda 0-3 (for 4 GPUs). I need to get hold of this T tensor (which has different values on different GPUs), compute some statistic from it, and then send that statistic back to all the GPUs.
Please suggest how this can be achieved.
The simple way I see to do this is the following:
You will first have to send all these tensors to a common GPU, aggregate the results, and compute your update. Then send the result back to each GPU.
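A minimal sketch of that idea, assuming a single process that can see every device and that all the tensors have the same shape and a floating-point dtype. The function name aggregate_on_one_device and the element-wise mean are placeholders for illustration only; swap in whatever statistic you actually need.

```python
import torch


def aggregate_on_one_device(tensors, target=None):
    """Copy per-device tensors to one device, aggregate, send the result back.

    tensors: list of same-shape tensors, one per device (hypothetical input).
    target:  device to aggregate on; defaults to the first tensor's device.
    """
    if target is None:
        target = tensors[0].device
    # Step 1: bring every tensor onto the common device.
    stacked = torch.stack([t.detach().to(target) for t in tensors])
    # Step 2: compute the update (placeholder: element-wise mean).
    stat = stacked.mean(dim=0)
    # Step 3: send one copy of the result back to each original device.
    return [stat.to(t.device) for t in tensors]
```

With four GPUs you would pass something like [T0, T1, T2, T3], where Ti lives on cuda:i; the same code also runs on CPU tensors, which is handy for a quick check.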
Thanks Alan! I was wondering how to send all the tensors to one GPU and perform my operations before the value of any tensor changes on its respective GPU.
Meaning, I don't want the values of these tensors on the different GPUs to change before I complete my operation and send the result back to the respective GPUs.
To send all tensors to one GPU, you'd want to use dist.gather, which will gather all of the tensors onto one GPU (this assumes you have one process running per GPU). If your tensor is t, then your call would look like:
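import torch
import torch.distributed as dist
t = your_tensor(size)
if rank == 0:
    # rank 0 is the process all tensors will be gathered on;
    # zeros_like keeps the receive buffers on the same device and dtype as t
    gathered_tensors = [torch.zeros_like(t) for _ in range(WORLD_SIZE)]
    dist.gather(t, gathered_tensors, dst=0)
else:
    dist.gather(t, dst=0)  # every other rank must call gather too, without a gather list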
Then you can compute what you want and send the result back, wrapped in a tensor, via dist.scatter. To make sure the tensors don't change on the GPUs before they are gathered, ensure all nodes call gather at the same point, when every node holds the desired value. You could also use torch.distributed.barrier if you need additional synchronization. Check out the docs at https://pytorch.org/docs/stable/distributed.html
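Putting the round trip together, here is a rough sketch rather than a definitive implementation. It assumes dist.init_process_group() has already been called with one process per GPU, and that t has the same shape and a floating-point dtype on every rank. The function name gather_compute_broadcast and the mean statistic are made up for illustration. Since every rank should receive the same result, the sketch sends it back with dist.broadcast instead of dist.scatter.

```python
import torch
import torch.distributed as dist


def gather_compute_broadcast(t, rank, world_size, dst=0):
    """Gather t from all ranks onto dst, compute one statistic there,
    and return that statistic on every rank."""
    # Optional extra synchronization: no rank continues until all have arrived.
    dist.barrier()

    if rank == dst:
        # Receive buffers on the same device and dtype as t.
        gathered = [torch.zeros_like(t) for _ in range(world_size)]
        dist.gather(t, gather_list=gathered, dst=dst)
        # Placeholder statistic: element-wise mean across ranks.
        stat = torch.stack(gathered).mean(dim=0)
    else:
        # Non-destination ranks take part in the gather without a list.
        dist.gather(t, dst=dst)
        # Buffer that broadcast overwrites with the result from dst.
        stat = torch.zeros_like(t)

    # Send the same result from dst to every rank.
    dist.broadcast(stat, src=dst)
    return stat
```

Each process would call this with its own rank at the same point in the training step. If the statistic you need is just a sum or a mean, dist.all_reduce can probably replace the gather and the send-back with a single collective.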
Thank you Rohan for the exhaustive explanation!