How can I receive the outputs from dist.all_gather_object() asynchronously?

I’m using NCCL with dist.all_gather_object() to send the outputs from all workers back to rank 0. However, the code below does not work asynchronously: rank 0 can only read the output once all the workers have sent their response. The script just hangs while waiting, and I cannot access any of the output in the meantime. I would like to implement this asynchronously. For example, if I uncomment the ‘time.sleep()’ for rank 2, I would like rank 0 to be able to read the output of worker 1 while it is still waiting for worker 2 to finish.

if rank == 0:
    output = {'acc': 1.89, 'loss': 1.2, 'extra_metric': 10.0}
    gather_objects = [output for i in range(world)]  # any picklable object
    output = [None for _ in gather_objects]
    dist.all_gather_object(output, gather_objects[dist.get_rank()])
    print(f"Rank output {rank} {output}")
else:
    if rank == 2:
        output = {'acc': 2.89, 'loss': 2.2, 'extra_metric': 20.0}
        gather_objects = [output for i in range(world)]
        # time.sleep(5)
    else:
        output = {'acc': 3.89, 'loss': 3.2, 'extra_metric': 30.0}
        gather_objects = [output for i in range(world)]

    output = [None for _ in gather_objects]
    dist.all_gather_object(output, gather_objects[dist.get_rank()])
    print(f"Rank output {rank} {output}")

If I execute the code above, I’m getting the following once all workers have sent back their output:

Rank output 0 [{'acc': 1.89, 'loss': 1.2, 'extra_metric': 10.0}, {'acc': 3.89, 'loss': 3.2, 'extra_metric': 30.0}, {'acc': 2.89, 'loss': 2.2, 'extra_metric': 20.0}]
Rank output 2 [{'acc': 1.89, 'loss': 1.2, 'extra_metric': 10.0}, {'acc': 3.89, 'loss': 3.2, 'extra_metric': 30.0}, {'acc': 2.89, 'loss': 2.2, 'extra_metric': 20.0}]
Rank output 1 [{'acc': 1.89, 'loss': 1.2, 'extra_metric': 10.0}, {'acc': 3.89, 'loss': 3.2, 'extra_metric': 30.0}, {'acc': 2.89, 'loss': 2.2, 'extra_metric': 20.0}]

I have to send a Python object (the dictionary); I cannot use tensors. This is why I’m using all_gather_object(). I also tried gather_object() instead, but it doesn’t work with NCCL.

Hi, currently there is no way to make these APIs asynchronous, but I filed a feature request asking for it: [c10d] Async object-based collectives · Issue #80417 · pytorch/pytorch · GitHub

Also, as of PyTorch 1.12, NCCL supports the gather API, so gather_object is enabled. You should be able to use gather_object in nightly builds, by building from source, or in 1.12 when it is released.
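For reference, a minimal sketch of how gather_object could be called in place of all_gather_object, assuming a build where it works with NCCL. The helper name gather_on_dst and its dist_module parameter are hypothetical: dist_module is meant to be the torch.distributed module, passed in explicitly so the helper can also be tried with a stub. Only the destination rank allocates the output list; every other rank passes None.

```python
def gather_on_dst(dist_module, obj, dst=0):
    """Gather one picklable object per rank onto rank `dst`.

    `dist_module` is assumed to be torch.distributed (already initialized).
    Returns the list of objects on rank `dst` and None on every other rank.
    """
    world = dist_module.get_world_size()
    # Only the destination rank provides a list to be filled in.
    gathered = [None] * world if dist_module.get_rank() == dst else None
    dist_module.gather_object(obj, gathered, dst=dst)
    return gathered
```

A call would look like gather_on_dst(dist, output) after import torch.distributed as dist. With NCCL, the documentation asks each rank to select its own GPU with torch.cuda.set_device() before using object collectives. This is still a blocking collective.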

This is great, thank you so much @rvarm1! However, gather_object also uses collective operations for the output list and, according to the documentation, still blocks any other call in the meantime :(
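One possible workaround, not suggested in the thread and only a sketch: skip the collective for the metrics and have each worker publish its pickled result to a key-value store that rank 0 polls. The helper names (publish_result, poll_results), the key layout, and the InMemoryStore stand-in are all made up for illustration. In a real job the store would be something like a torch.distributed.TCPStore shared by all ranks. The sketch assumes only its set/get/add methods, and that add(key, 0) reads a counter without blocking, which should be checked against the PyTorch version in use.

```python
import base64
import pickle


class InMemoryStore:
    """Single-process stand-in exposing only set/get/add, so the sketch
    runs without a cluster. Not a real torch.distributed store."""

    def __init__(self):
        self._data = {}

    def set(self, key, value):
        self._data[key] = value.encode() if isinstance(value, str) else value

    def get(self, key):
        return self._data[key]

    def add(self, key, amount):
        current = int(self._data.get(key, b"0")) + amount
        self._data[key] = str(current).encode()
        return current


def publish_result(store, rank, obj):
    """Called by a worker as soon as its own result is ready."""
    # Base64 keeps the pickled bytes ASCII-safe (a precaution, assumed useful).
    payload = base64.b64encode(pickle.dumps(obj)).decode("ascii")
    store.set(f"result/{rank}", payload)
    # Raise the ready flag only after the payload has been stored.
    store.add(f"ready/{rank}", 1)


def poll_results(store, pending):
    """Return {rank: obj} for the ranks in `pending` that have published."""
    arrived = {}
    for rank in sorted(pending):
        # add(key, 0) reads the flag without waiting for the key to exist.
        if store.add(f"ready/{rank}", 0) > 0:
            raw = store.get(f"result/{rank}")
            arrived[rank] = pickle.loads(base64.b64decode(raw))
    return arrived
```

On rank 0 this would be called in a loop: while worker 2 is still sleeping, poll_results(store, {1, 2}) returns only worker 1's dictionary, and the remaining rank is polled again later.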