CUDA allocation lifetime for inputs to distributed.all_reduce

I am debugging a memory usage issue in tensor-parallel training, and I observe that runs with tensor parallelism enabled that OOM show large amounts of “reserved but unallocated” memory (pointing to fragmentation). If I remove the all_reduce calls in the TP region (but change nothing else), the max batch size that fits before OOM goes up significantly (up to 2x), and there is no sign of fragmentation.

Digging in, I have noticed the following odd behavior when you perform an all_reduce as part of model execution (as you do in tensor-parallel training). Consider this toy program that interleaves matmuls with all_reduce operations:

import os

import torch


class Model(torch.nn.Module):
    def __init__(self, dim, nlayers):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for _ in range(nlayers):
            self.layers.append(torch.nn.Linear(dim, dim, bias=False))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
            torch.distributed.all_reduce(x)
        return x


def main():
    rank = int(os.environ["RANK"])
    assert rank >= 0
    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(f"cuda:{rank}")

    dim = 8192
    nlayers = 8
    batch = 4096
    torch.cuda.memory._record_memory_history()

    model = Model(dim, nlayers)
    model.bfloat16()
    model.cuda()
    x = torch.randn((batch, dim), dtype=torch.bfloat16, device="cuda")

    with torch.no_grad():
        for _ in range(5):
            out = model(x)  # output is discarded; we only care about allocator behavior

    torch.cuda.memory._dump_snapshot(f"snapshot-ar-{rank}.pkl")


if __name__ == "__main__":
    main()

I have run this on 2 GPUs with torchrun --standalone --nproc_per_node=2 memtest.py.
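(For anyone reproducing this: the .pkl snapshots can be viewed by dropping them onto the interactive viewer at https://pytorch.org/memory_viz; there is also a CLI in torch.cuda._memory_viz that can render an HTML trace, something along the lines of

python -m torch.cuda._memory_viz trace_plot snapshot-ar-0.pkl -o snapshot-ar-0.html

though the exact invocation may vary by PyTorch version.)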

Here is what the memory visualizer shows in Active Memory for that code:

In contrast, if I remove the call to torch.distributed.all_reduce (but change nothing else), the visualizer shows:

The non-all_reduce picture makes sense: you can see each of the five iterations of the model, and within each iteration you can see the allocations per-layer for each matmul.

But the run with all_reduce enabled has long-lasting allocations that live both across layers and across iterations of the model. Each of those long-lived allocations is an input to the all_reduce call. It makes sense that they should live a little longer (since all_reduce operates in place), but it is perplexing that they persist across model iterations.

My goal is to make the top picture look like the bottom one, so I’d love suggestions about what to try and / or any deeper insight into what the allocator is doing here. Thanks.

Bumping this. Begging @ptrblck for pointers to internals / things I could look at to start debugging.

-Carl

Hi Carl!

Thanks for reporting the issue and the great memory allocation plots.
You might be running into this issue. In this PR we exposed an env variable, TORCH_NCCL_AVOID_RECORD_STREAMS=1, to avoid the stream recording. Could you try it and let me know if the memory allocation behavior changes?
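E.g., with the same launch command as above, that would just be:

TORCH_NCCL_AVOID_RECORD_STREAMS=1 torchrun --standalone --nproc_per_node=2 memtest.py

(Setting it via os.environ early in the script, before init_process_group, should also work, since the process group reads the variable when it is created.)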

Oh my gosh, I’m floored. I can’t imagine I’d ever have found that issue on my own (of course, I’d even forgotten that PyTorch puts NCCL kernels on its own internal stream). Setting that env var and running the same code with all_reduce enabled gives me this beautiful picture:

I.e., it’s ~identical to the no-all_reduce allocator behavior. I still need to confirm that this pays off as better perf in real runs, but I am optimistic.

Thanks so much, Piotr! (Love that this is a Carilli fix, tell him I say hi next time you guys cross paths!)


record_stream calls are generally considered harmful, and there is a long post on dev-discuss describing their dangers that you can read if you are interested: FSDP & CUDACachingAllocator: an outsider newb perspective - distributed - PyTorch Dev Discussions
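For anyone following along, here is a minimal sketch (mine, not from the linked post) of the mechanism being discussed: Tensor.record_stream tells the caching allocator that a tensor is in use on another stream, and when the tensor is later freed the allocator will not hand its block out again until the work that was queued on that stream has completed. ProcessGroupNCCL does essentially this for collective inputs on its internal NCCL stream, which is why the all_reduce inputs above outlive their Python references.

import torch

side = torch.cuda.Stream()
x = torch.empty(1 << 20, device="cuda")  # block comes from the caching allocator

with torch.cuda.stream(side):
    y = x * 2  # kernel on the side stream reads x

x.record_stream(side)  # mark x as in use on `side`
del x  # the block is freed, but the allocator defers reuse
       # until the work pending on `side` at this point finishes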


Just FYI: we started rolling out avoid_record_stream as the default behavior for synchronous comm ops.
It currently covers all_gather, reduce_scatter, scatter, and broadcast. See these two PRs:

We plan to cover all collectives by the 2.2.0 release.


@ngimel – thanks for the pointer, I learned a lot reading that.

@kwen2501 – good to know; sad that all_reduce missed the first rollout, as I would have been able to blindly go about my life knowing nothing of this 🙂

Follow-up question @ptrblck: how do these settings interact with CUDA graph capture? Is it largely the same, except that the allocations come from the graph-private memory pool? If that’s the case, there should also be a benefit to setting the AVOID_RECORD_STREAMS flag for graph capture, since it would let the allocator reuse the same buffers for collective inputs during capture (and therefore during replay), as opposed to (likely) needing a bunch of distinct buffers, depending on the exact timing of when the allocations occur.
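For context, the capture pattern I have in mind is the standard one from the CUDA graphs docs, roughly like this (just a sketch reusing the toy model above; warmup on a side stream is required before capture, and capturing the all_reduce assumes an NCCL build that supports graph capture):

g = torch.cuda.CUDAGraph()
static_x = torch.randn((batch, dim), dtype=torch.bfloat16, device="cuda")

# warmup on a side stream before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    model(static_x)
torch.cuda.current_stream().wait_stream(s)

# capture: allocations made here come from the graph-private memory pool
with torch.cuda.graph(g), torch.no_grad():
    static_out = model(static_x)

# replay re-runs the captured kernels against the same graph-private buffers
static_x.copy_(x)
g.replay()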