How does allgather handle memory allocation?

Hi,

I am noticing that calling torch.distributed.all_gather_object on CPU tensors does not fully release RAM after the call. As a result, my training job's RAM usage keeps climbing as larger tensors are gathered, since some frames can carry more data than others.

I used this snippet to test out just this function itself:

    import time
    import torch
    import torch.distributed

    # Assumes the default process group has already been initialized elsewhere.
    def f(i):
        # CPU tensors whose size grows with i
        scores = torch.ones(i * 10**7, dtype=torch.float32)
        fn = torch.zeros(i * 10**7, dtype=torch.bool)
        tp = torch.zeros(i * 10**7, dtype=torch.bool)
        fp = torch.zeros(i * 10**7, dtype=torch.bool)
        gathered_scores = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(gathered_scores, scores.to("cpu"))
        gathered_fn = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(gathered_fn, fn.to("cpu"))
        gathered_tp = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(gathered_tp, tp.to("cpu"))
        gathered_fp = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(gathered_fp, fp.to("cpu"))

    f(2)
    time.sleep(60)
    f(3)
    time.sleep(60)

and the RAM graph looks like this:
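(The graph comes from our job's memory monitoring, so it isn't produced by the snippet above. To reproduce the same measurement locally I just sample the process RSS in a background thread, roughly like the sketch below; psutil here is only for illustration and isn't part of the training code.)

    import threading
    import time
    import psutil

    def log_rss(interval=1.0):
        # Periodically print this process's resident set size in GB
        proc = psutil.Process()
        while True:
            print(f"RSS: {proc.memory_info().rss / 1e9:.2f} GB")
            time.sleep(interval)

    threading.Thread(target=log_rss, daemon=True).start()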

I’m not a memory expert, so I wonder: does all_gather_object use pinned memory or some internal buffer under the hood that keeps holding references to the Python objects? Or am I simply using all gather wrong?
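One thing I'm planning to try next, to rule out that it's simply my own references or delayed garbage collection, is dropping the gathered list and forcing a collection right after the call. A minimal sketch of that test (again assuming the default process group is already initialized):

    import gc
    import torch
    import torch.distributed

    def gather_once(t):
        gathered = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(gathered, t.to("cpu"))
        # Drop every reference to the gathered objects and force a collection.
        del gathered
        gc.collect()

    gather_once(torch.ones(2 * 10**7, dtype=torch.float32))

If RSS still stays high after this, I'd take that as a sign the memory is being held somewhere inside the collective rather than by my Python objects.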

Thanks!