I want to overlap the following two operations: (1) a big matmul and (2) a distributed cumulative sum across 8 nodes.
The first way I tried to do this was:
# A, B, and C are arrays that already exist
future = irecv(previous_acc, RANK - 1)   # post the receive for the upstream partial sum (skipped on node 0)
compute_matmul(B, C)                     # big matmul, meant to overlap with the in-flight receive
future.wait()                            # block until the upstream partial sum has arrived
isend(previous_acc + A, RANK + 1)        # only now forward the running sum downstream (node 7 sends nothing)
But this would only overlap (at best) the very first communication (0 → 1) with computation: node 0 will not send to node 1 until it has finished its matmul, and every later node likewise only forwards after both its own matmul and its upstream receive have completed, so the hops from 1 → 2 onward ripple through after all the matmuls are already done.
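To make the serialization concrete, here is a rough back-of-the-envelope timing model of this first pattern (the numbers and the names T_MATMUL and T_HOP are made up purely for illustration):

# Illustrative timing model for the first pattern (hypothetical numbers)
T_MATMUL = 10.0   # assumed time for one node's big matmul
T_HOP = 1.0       # assumed time for one add-and-forward hop
NODES = 8
ready_to_send = 0.0
for k in range(NODES):
    # node k forwards its partial sum only after its own matmul is done
    # AND (for k > 0) the partial sum from node k-1 has arrived
    arrived = ready_to_send + T_HOP if k > 0 else 0.0
    ready_to_send = max(T_MATMUL, arrived)
print(ready_to_send)  # ~= T_MATMUL + 7 * T_HOP: the hops are serialized after the matmuls finish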
To try to fix this, I considered a second way:
# A, B, and C are arrays that already exist
recv(previous_acc, RANK - 1)        # blocking receive of the upstream partial sum (skipped on node 0)
isend(previous_acc + A, RANK + 1)   # forward the running sum downstream immediately (node 7 sends nothing)
compute_matmul(B, C)                # only then start the big matmul
But this would mean that node 7 has to wait until the whole chain of communications (0 → 1 → … → 7) has completed before it even starts its matmul, so once again we fail to overlap.
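Using the same kind of made-up numbers as above, the last node in this second pattern ends up no better off:

# Illustrative timing for the second pattern on the last node (hypothetical numbers)
T_MATMUL, T_HOP, NODES = 10.0, 1.0, 8
# node 7 cannot start its matmul until the chain 0 -> 1 -> ... -> 7 has fully arrived
t_last_node = (NODES - 1) * T_HOP + T_MATMUL
print(t_last_node)  # ~= 7 * T_HOP + T_MATMUL, the same un-overlapped total as the first pattern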
Is there a pattern or tool that solves this? I’ve tried splitting execution into separate Python threads, and even into separate CUDA streams, but nothing has worked so far.
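For reference, the thread version I tried looked roughly like this (simplified; the rank guards and function names here are just placeholders, and node 0's previous_acc is assumed to start as zeros):

import threading

def comm_chain():
    # blocking chain: receive the upstream partial sum, add A, forward it
    if RANK > 0:
        recv(previous_acc, RANK - 1)
    if RANK < 7:
        isend(previous_acc + A, RANK + 1)

def local_compute():
    compute_matmul(B, C)

comm = threading.Thread(target=comm_chain)
compute = threading.Thread(target=local_compute)
comm.start(); compute.start()
comm.join(); compute.join()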