How to Efficiently Gather Python Objects Across GPUs Without GPU-to-CPU-to-GPU-to-CPU Overhead in torch.distributed?

I am running a distributed inference task where each GPU predicts one frame of an image per sample. At the end of each epoch, I need to gather all the predicted frames from all GPU processes into a single tensor. Finally, I will save the gathered frames to a file.

I came across torch.distributed.all_gather_object() to achieve this, but its documentation includes the following notes:

  1. Note
    For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user's responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().
  2. Warning
    Calling all_gather_object() with GPU tensors is not well supported and inefficient as it incurs GPU → CPU transfer since tensors would be pickled. Please consider using all_gather() instead.

If I use all_gather_object() with GPU tensors, the following happens:

  • GPU to CPU transfer occurs.
  • Pickling and gathering happens.
  • CPU to GPU transfer occurs after gathering.

After this, to save the gathered results to a file, I will need another:

  • GPU to CPU transfer.

The CPU-to-GPU transfer after gathering is redundant for my use case, since the data ends up on the CPU anyway, and materializing every rank's tensors on each GPU increases the risk of out-of-memory (OOM) errors.
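For reference, the pattern described above looks like this. This is a minimal single-process sketch using the gloo backend so it runs on CPU; in the real NCCL job with CUDA tensors, the pickle/unpickle round trips listed above apply. The filename and tensor shape are placeholders.

```python
import os
import torch
import torch.distributed as dist

# Single-process group for illustration; in the real job this would be
# the NCCL group launched by torchrun with one rank per GPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# Per-rank predictions: {filename: (frame_idx, tensor)}.
predictions = {"sample_000.png": (0, torch.randn(3, 8, 8))}

# all_gather_object pickles each rank's dict; with CUDA tensors this
# forces GPU -> CPU on pickling and CPU -> GPU on unpickling, on
# every rank, even if only one rank needs the result.
gathered = [None] * dist.get_world_size()
dist.all_gather_object(gathered, predictions)

# Every rank now holds every rank's dict.
print(sorted(gathered[0].keys()))
dist.destroy_process_group()
```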

My Goals:

  1. Since my prediction results are stored in a dictionary with a key-value structure like {filename: (frame_idx, tensor)}, I need to gather arbitrary Python objects (at least dictionaries).
  2. I cannot ensure that frames are synchronized or evenly distributed across GPU processes. Since all_gather() requires every rank to contribute a tensor of the same shape, using it directly would lead to shape-mismatch and synchronization issues.
  3. I want to avoid transferring tensors back to the GPU after gathering, as this can cause additional memory overhead and potential OOM issues.
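To make goal 2 concrete: all_gather() only works when every rank supplies same-shaped tensors, so uneven per-rank frame counts would force manual padding plus an extra length exchange before the collective. A runnable sketch of the equal-shape case (gloo backend, single process, placeholder shapes):

```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29502")
dist.init_process_group("gloo", rank=0, world_size=1)

world_size = dist.get_world_size()
# all_gather requires identical shapes on all ranks; with uneven
# frame counts per rank, each rank would first have to pad its
# tensor to a common size and separately communicate true lengths.
local = torch.full((2, 4), float(dist.get_rank()))
buckets = [torch.empty_like(local) for _ in range(world_size)]
dist.all_gather(buckets, local)

stacked = torch.stack(buckets)  # shape: (world_size, 2, 4)
dist.destroy_process_group()
```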

Question:

What is the best approach to gather arbitrary Python objects (e.g., dictionaries) across ALL GPU processes to ONLY ONE CPU process efficiently without incurring unnecessary GPU-to-CPU-to-GPU-to-CPU transfers?
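One candidate I am considering (a sketch, not a confirmed answer) is to move tensors to CPU before any communication and use gather_object(), which delivers objects only to the destination rank. Shown here with a single-process gloo group so it is runnable on CPU; in a real NCCL job, the usual suggestion is to create a side CPU group with dist.new_group(backend="gloo") and pass it as group=.

```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

# Detach and move tensors to CPU *before* communication, so nothing
# travels back to the GPU after gathering.
predictions = {"sample_000.png": (0, torch.randn(3, 8, 8))}
cpu_predictions = {k: (i, t.detach().cpu())
                   for k, (i, t) in predictions.items()}

# gather_object sends every rank's object to dst only; non-destination
# ranks pass object_gather_list=None.
rank = dist.get_rank()
gathered = [None] * dist.get_world_size() if rank == 0 else None
dist.gather_object(cpu_predictions, gathered, dst=0)

if rank == 0:
    merged = {}
    for d in gathered:
        merged.update(d)
    # merged holds all {filename: (frame_idx, tensor)} pairs on CPU,
    # ready for torch.save without any further device transfer.
dist.destroy_process_group()
```

This keeps the whole path CPU-only after inference: one GPU-to-CPU copy per rank, pickling over the wire, and no return trip to any GPU.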

Update:

I have since found the PyTorch library TensorDict. According to its documentation, it can be used to exchange data across processes in a distributed environment. I'm going to try it and see if it works.