Hi, I am implementing a retrieval model with DDP on 8 GPUs.
I have to encode all 5.9M Wikipedia articles with the model and save the encoded results (the Transformer output corresponding to the CLS token).
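For context, by "encoded results" I mean roughly the sketch below, assuming a HuggingFace-style encoder; `encoder` and `batch` are placeholder names, not my actual code:

```python
# Sketch of extracting the CLS vector (HuggingFace-style encoder assumed;
# `encoder` and `batch` are placeholders).
outputs = encoder(input_ids=batch["input_ids"],
                  attention_mask=batch["attention_mask"])
cls_vecs = outputs.last_hidden_state[:, 0]  # (batch_size, H=768)
```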
By using a `DistributedSampler`, each GPU can encode its own shard of the articles under DDP; my setup is roughly the sketch below.
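(`wiki_dataset` and the batch size are placeholders.)

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# `wiki_dataset` is a placeholder for the 5.9M-article dataset.
# With shuffle=False, rank r gets the interleaved indices
# r, r + world_size, r + 2 * world_size, ..., and drop_last=False pads
# with duplicate indices so that every rank sees the same number of
# batches. This interleaving is why I track example indices below.
sampler = DistributedSampler(wiki_dataset, shuffle=False, drop_last=False)
loader = DataLoader(wiki_dataset, batch_size=256, sampler=sampler)
```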
But when I called `torch.distributed.all_gather` on the encoded vectors, CUDA out of memory occurred. So I gave up on `all_gather` and started manually gathering the results from each GPU by saving and loading them:
```python
import os
import torch
import torch.distributed as dist

# Each rank saves its own list of encoded vectors to a shared filesystem.
# (self.trainer.local_rank is only unique on a single node; dist.get_rank()
# would be the global rank for multi-node runs.)
torch.save(cand_vecs, os.path.join(exp_root, f"cand_vecs_{self.trainer.local_rank}.pt"))
dist.barrier()  # wait until all cand_vecs are saved

# Load the per-rank files in ascending rank order and concatenate.
all_rank_cand_vecs = []
for rank in range(dist.get_world_size()):
    vecs = torch.load(os.path.join(exp_root, f"cand_vecs_{rank}.pt"))
    all_rank_cand_vecs.append(torch.stack(vecs, dim=0))
all_rank_cand_vecs = torch.cat(all_rank_cand_vecs, dim=0)  # (num_candidates, H=768)

# I also saved the data indices to recover the original dataset order.
all_rank_example_indices = [None for _ in range(dist.get_world_size())]
dist.all_gather_object(all_rank_example_indices, example_indices)

# recover the original dataset order
...
```
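The recovery step I have in mind is roughly the sketch below, assuming each rank's `example_indices` lists dataset indices in the same order as that rank's `cand_vecs`:

```python
# Flatten the gathered indices in the same rank order as the vectors.
flat_indices = torch.cat(
    [torch.as_tensor(idx) for idx in all_rank_example_indices], dim=0
)
# If the indices are a permutation of 0..N-1, row order[j] holds the vector
# for dataset index j, so this restores the original dataset order.
# (If the sampler padded with duplicate indices, those would need
# de-duplication first.)
order = torch.argsort(flat_indices)
ordered_cand_vecs = all_rank_cand_vecs[order]  # (num_candidates, H=768)
```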
Is my code correct? I am a bit worried that the order of my manual gathering might differ from that of the native `all_gather`; if it does, I cannot recover the original dataset order. I believe the list returned by `all_gather` corresponds to ranks 0, 1, 2, 3, 4, and so on.
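For example, I assume the following tiny check would print `[0, 1, ..., world_size - 1]` on every rank (a sketch, not something I have verified):

```python
import torch.distributed as dist

# Sanity check of my assumption: all_gather_object should fill the output
# list indexed by rank, so every process prints the same ordered list.
gathered = [None] * dist.get_world_size()
dist.all_gather_object(gathered, dist.get_rank())
print(gathered)  # expected: [0, 1, 2, ..., world_size - 1]
```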