Manually gathering tensors to avoid a CUDA out-of-memory error

Hi, I am implementing a retrieval model with DDP on 8 GPUs.
I need to encode all 5.9M Wikipedia articles with the model and save the encoded results (the Transformer output at the CLS token).

Using DistributedSampler, each GPU encodes its own shard of the articles under DDP. But when I called torch.distributed.all_gather, I got a CUDA out-of-memory error.
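A back-of-the-envelope estimate suggests why a GPU all_gather is so heavy here. It assumes the CLS vectors are stored as fp32 (4 bytes) or fp16 (2 bytes) and ignores sampler padding. `dist.all_gather` fills a preallocated output list with every rank's shard, so each GPU has to hold the full 5.9M × 768 matrix on top of the model. `gathered_bytes` is a hypothetical helper for this arithmetic only.

```python
def gathered_bytes(num_examples: int, hidden: int, bytes_per_elem: int) -> int:
    """Size of the full (num_examples, hidden) matrix that all_gather materializes on every rank."""
    return num_examples * hidden * bytes_per_elem


NUM_ARTICLES = 5_900_000  # number of Wikipedia articles from the post
HIDDEN = 768              # CLS vector size from the post

fp32 = gathered_bytes(NUM_ARTICLES, HIDDEN, 4)  # assumes float32 embeddings
fp16 = gathered_bytes(NUM_ARTICLES, HIDDEN, 2)  # assumes float16 embeddings
print(f"fp32: {fp32 / 1e9:.1f} GB per GPU")  # fp32: 18.1 GB per GPU
print(f"fp16: {fp16 / 1e9:.1f} GB per GPU")  # fp16: 9.1 GB per GPU
```

Casting the vectors to fp16 before gathering would halve this, but the full matrix still has to fit next to the model on every GPU.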

So I gave up on all_gather and started gathering the results from each GPU manually, by saving them to disk and loading them back.

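```python
import os
import torch
import torch.distributed as dist

# Use the global rank so the file names match the range(world_size) loop below
# (self.trainer.local_rank equals it only on a single node).
rank = dist.get_rank()
torch.save(cand_vecs, os.path.join(exp_root, f"cand_vecs_{rank}.pt"))
dist.barrier()  # wait until all cand_vecs are saved

all_rank_cand_vecs = []
for r in range(dist.get_world_size()):
    # map_location="cpu": otherwise each shard is restored to the GPU it was saved from
    shard = torch.load(os.path.join(exp_root, f"cand_vecs_{r}.pt"), map_location="cpu")
    all_rank_cand_vecs.append(torch.stack(shard, dim=0))
all_rank_cand_vecs = torch.concat(all_rank_cand_vecs, dim=0)  # (num_candidates, H=768)

# I also gather the dataset indices to recover the original dataset order.
all_rank_example_indices = [None for _ in range(dist.get_world_size())]
dist.all_gather_object(all_rank_example_indices, example_indices)

# recover the original dataset order
...
```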
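One possible way to implement the order-recovery step, assuming `example_indices` on each rank lists dataset indices in the same order as that rank's `cand_vecs`: flatten the gathered index lists in rank order (the same order the shards are loaded in), then write each row to its dataset position. This is a pure-Python sketch; `restore_order` is a hypothetical helper, not a PyTorch API.

```python
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def restore_order(
    rows: Sequence[T],
    per_rank_indices: Sequence[Sequence[int]],
    num_examples: int,
) -> List[T]:
    """Put rank-major gathered rows back at their original dataset positions.

    rows: per-rank outputs concatenated in rank order (rank 0's rows first).
    per_rank_indices: the list filled by all_gather_object (entry i = rank i),
        each entry listing dataset indices in the same order as that rank's rows.
    """
    flat_indices = [i for indices in per_rank_indices for i in indices]
    if len(flat_indices) != len(rows):
        raise ValueError("rows and indices are not aligned")
    out: List[Optional[T]] = [None] * num_examples
    for idx, row in zip(flat_indices, rows):
        out[idx] = row  # a padded duplicate just rewrites the same slot
    missing = [i for i, r in enumerate(out) if r is None]
    if missing:
        raise ValueError(f"{len(missing)} dataset indices were never encoded")
    return out  # type: ignore[return-value]


# 5 examples on 2 ranks: rank 0 encoded [0, 2, 4], rank 1 encoded [1, 3, 0]
# (the trailing 0 is sampler padding).
rows = ["v0", "v2", "v4", "v1", "v3", "v0"]
print(restore_order(rows, [[0, 2, 4], [1, 3, 0]], 5))  # ['v0', 'v1', 'v2', 'v3', 'v4']
```

With tensors, the same scatter can be written as `out = torch.empty(num_examples, H, dtype=all_rank_cand_vecs.dtype)` followed by `out[torch.tensor(flat_indices)] = all_rank_cand_vecs`. Duplicated indices refer to the same article and, assuming the model runs in eval mode, carry essentially the same vector, so it should not matter which write lands last.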

Is my code correct? I am a bit worried that the order of my manual gathering might differ from that of the native all_gather; if it does, I cannot recover the original dataset order.
I believe the list returned by all_gather is ordered by rank: rank 0, 1, 2, 3, and so on.
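To see why the gathered indices are still needed, here is a small pure-Python re-implementation (not the actual `torch.utils.data.DistributedSampler` code; `sampler_indices` is a hypothetical name) of how the sampler splits indices with `shuffle=False` and `drop_last=False`, as I understand its behavior. The index list is padded to a multiple of the world size, then rank `r` takes every `world_size`-th index starting at `r`. Both the `range(world_size)` loop and `all_gather` produce rank-major order (entry `i` comes from rank `i`), so the two agree, but rank-major order is interleaved rather than dataset order and can contain padded duplicates.

```python
import math
from typing import List


def sampler_indices(num_examples: int, num_replicas: int, rank: int) -> List[int]:
    """Mimics DistributedSampler's index split for shuffle=False, drop_last=False."""
    num_samples = math.ceil(num_examples / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(num_examples))
    # Pad by repeating indices from the start so every rank gets num_samples items.
    padding = total_size - len(indices)
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


world_size = 3
per_rank = [sampler_indices(7, world_size, r) for r in range(world_size)]
print(per_rank)  # [[0, 3, 6], [1, 4, 0], [2, 5, 1]]

# Loading the files in rank order, or reading all_gather's output list, gives this layout:
print([i for shard in per_rank for i in shard])  # [0, 3, 6, 1, 4, 0, 2, 5, 1]
```

With `shuffle=True`, a seeded permutation replaces `range(num_examples)`, but the padding and striding work the same way.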

Would you be able to move the results to the CPU and call all_gather on them there, or is that too slow for your use case? That would be a workaround for the CUDA OOM.
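A minimal sketch of that suggestion, assuming the job's default process group uses NCCL. NCCL only communicates CUDA tensors, so the CPU all_gather goes through a separate gloo group created with `dist.new_group(backend="gloo")`. `all_gather_on_cpu` is a hypothetical helper name. Every rank still ends up holding the full matrix in host RAM (about 18 GB in fp32 for 5.9M × 768), so gathering only to rank 0 with `dist.gather` may be preferable.

```python
import torch
import torch.distributed as dist


def all_gather_on_cpu(local_vecs: torch.Tensor, group=None) -> torch.Tensor:
    """All-gather a (num_local, H) tensor through CPU memory.

    `group` should be a gloo process group, since NCCL only handles CUDA tensors.
    Every rank must pass the same shape; DistributedSampler with drop_last=False
    pads each rank to the same number of samples.
    """
    local_vecs = local_vecs.detach().cpu().contiguous()
    world_size = dist.get_world_size(group)
    shards = [torch.empty_like(local_vecs) for _ in range(world_size)]
    dist.all_gather(shards, local_vecs, group=group)
    return torch.cat(shards, dim=0)  # rank-major: rank 0's rows come first


# Inside the DDP job (NCCL default group), create the CPU group once on all ranks:
#   cpu_group = dist.new_group(backend="gloo")
#   all_rank_cand_vecs = all_gather_on_cpu(torch.stack(cand_vecs), group=cpu_group)
```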