Hi,
I am noticing that torch.distributed.all_gather_object on CPU tensors does not fully release RAM after the call. As a result, my training job’s RAM keeps climbing as larger tensors are gathered, since some frames can carry more data than others.
I used the following snippet to test just this function in isolation:
import time

import torch
import torch.distributed

# (The process group is assumed to already be initialized by the training job.)
def f(i):
    # Allocate CPU tensors whose size grows with i, then all-gather each one.
    scores = torch.ones(i * 10**7, dtype=torch.float32)
    fn = torch.zeros(i * 10**7, dtype=torch.bool)
    tp = torch.zeros(i * 10**7, dtype=torch.bool)
    fp = torch.zeros(i * 10**7, dtype=torch.bool)
    gathered_scores = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(gathered_scores, scores.to("cpu"))
    gathered_fn = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(gathered_fn, fn.to("cpu"))
    gathered_tp = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(gathered_tp, tp.to("cpu"))
    gathered_fp = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(gathered_fp, fp.to("cpu"))

f(2)
time.sleep(60)
f(3)
time.sleep(60)
and the RAM graph looks like this:
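To put rough numbers on the graph, the driver above could also be instrumented with an RSS probe, for example with psutil (just a sketch; psutil and the print format are assumptions, not part of the original run):

import os
import time

import psutil  # assumed to be installed; any RSS monitor would do
import torch.distributed

def log_rss(tag):
    # Resident set size of this rank's process, in MiB.
    rss_mib = psutil.Process(os.getpid()).memory_info().rss / 2**20
    print(f"[rank {torch.distributed.get_rank()}] {tag}: RSS = {rss_mib:.1f} MiB")

log_rss("before f(2)")
f(2)  # f as defined in the snippet above
log_rss("after f(2)")
time.sleep(60)
log_rss("60 s after f(2)")
log_rss("before f(3)")
f(3)
log_rss("after f(3)")
time.sleep(60)
log_rss("60 s after f(3)")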
I’m not a memory expert, so I wonder: does all_gather_object use pin_memory or something similar under the hood that keeps holding references to the Python objects? Or am I doing the all-gather wrong?
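One check I can think of is whether the memory is really held by lingering references, or just not returned to the OS by the allocator: drop every reference right after the gather, force a garbage collection, and then ask glibc to give freed pages back. A sketch, assuming Linux/glibc (malloc_trim is glibc-specific):

import ctypes
import gc

import torch
import torch.distributed

def gather_and_release(i):
    # Same kind of gather as in f(i), but with every reference dropped afterwards.
    scores = torch.ones(i * 10**7, dtype=torch.float32)
    gathered = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(gathered, scores)
    del gathered, scores
    gc.collect()
    try:
        # Ask the allocator to hand freed arenas back to the OS.
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except (OSError, AttributeError):
        pass  # not on glibc

If the RAM graph only comes back down after the malloc_trim call, the memory was being cached by the allocator rather than held by references to the gathered objects; if it never comes down, something really is still referencing them.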
Thanks!