How to achieve constant memory cost when capturing CUDA graphs for multiple batch sizes?

What I want to achieve is constant memory cost when capturing multiple CUDA graphs:

input_buffer = torch.empty((MAX_SIZE, DIM))
output_buffer = torch.empty((MAX_SIZE, DIM))

for bs in range(1, MAX_SIZE + 1):
    with graph_capture():  # capture the graph, sharing one memory pool
        input = input_buffer[:bs]
        output = net(input)
        output_buffer[:bs].copy_(output)

Ideally, graph capture in this case should only cost the activation memory of the maximum batch size.
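To make the "share memory pool" part concrete, here is a minimal sketch of how I would expect this to be wired up with torch.cuda.graph_pool_handle(); net is just a stand-in Linear layer here, and the snippet is illustrative rather than a tested solution:

import torch

MAX_SIZE, DIM = 128, 1024
net = torch.nn.Linear(DIM, DIM).cuda()  # placeholder model standing in for `net` above

input_buffer = torch.empty((MAX_SIZE, DIM), device="cuda")
output_buffer = torch.empty((MAX_SIZE, DIM), device="cuda")

# Warm up outside of capture (lazy cuBLAS/cuDNN initialization is not capturable)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    net(input_buffer)
torch.cuda.current_stream().wait_stream(s)

pool = torch.cuda.graph_pool_handle()  # one pool shared by every capture
graphs = {}
for bs in range(1, MAX_SIZE + 1):
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):
        output_buffer[:bs].copy_(net(input_buffer[:bs]))
    graphs[bs] = g

# Replay: write into the static input buffer, replay, read the static output buffer.
input_buffer[:4].normal_()
graphs[4].replay()
torch.cuda.current_stream().synchronize()
print(output_buffer[:4])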

However, I find in practice this still costs memory:

import torch
from contextlib import contextmanager

@contextmanager
def graph_capture(pool=None, stream=None, capture_error_mode: str = "global", dump_path=None):
    g = torch.cuda.CUDAGraph()
    if dump_path is not None:
        g.enable_debug_mode()
    with torch.cuda.graph(cuda_graph=g, pool=pool, stream=stream, capture_error_mode=capture_error_mode):
        yield g
    if dump_path is not None:
        g.debug_dump(dump_path)

import ctypes

# Load the CUDA runtime library
cudart = ctypes.CDLL('libcudart.so')

# Define cudaMemcpyKind enumeration as in the CUDA API
cudaMemcpyHostToHost = 0
cudaMemcpyHostToDevice = 1
cudaMemcpyDeviceToHost = 2
cudaMemcpyDeviceToDevice = 3
cudaMemcpyDefault = 4

# Setup the prototype of the cudaMemcpyAsync function
cudaMemcpyAsync = cudart.cudaMemcpyAsync
cudaMemcpyAsync.argtypes = [
    ctypes.c_void_p,          # void* dst
    ctypes.c_void_p,          # const void* src
    ctypes.c_size_t,          # size_t count
    ctypes.c_int,             # enum cudaMemcpyKind
    ctypes.c_void_p           # cudaStream_t stream
]
cudaMemcpyAsync.restype = ctypes.c_int


MAX_BATCHSIZE = 128

# Placeholder pinned-host input/output buffers used for capture
static_a = torch.zeros((MAX_BATCHSIZE, 1024), device="cpu").pin_memory()
static_b = torch.zeros((MAX_BATCHSIZE, 1024), device="cpu").pin_memory()
static_output = torch.zeros((MAX_BATCHSIZE, 1024), device="cpu").pin_memory()

def compute(batchsize):
    # H2D copies from pinned host buffers (capturable by CUDA graphs)
    a = static_a[:batchsize].to("cuda", non_blocking=True)
    b = static_b[:batchsize].to("cuda", non_blocking=True)
    output = (a ** 2 + b * 2)
    # D2H copy back into the pinned output buffer via the raw CUDA runtime,
    # so the destination pointer stays fixed across captures
    result = cudaMemcpyAsync(static_output.data_ptr(), output.data_ptr(), output.numel() * output.element_size(), cudaMemcpyDeviceToHost, torch.cuda.current_stream().cuda_stream)
    assert result == 0, "cudaMemcpyAsync failed"
    return static_output[:batchsize]

# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(1, MAX_BATCHSIZE + 1):
        compute(i)
torch.cuda.current_stream().wait_stream(s)

def report_memory(prefix):
    free, total = torch.cuda.mem_get_info()
    used = total - free
    print(f"{prefix}: Used: {used / 1024 / 1024} MB, Free: {free / 1024 / 1024} MB, Total: {total / 1024 / 1024} MB")

# Capture the graphs.
# torch.cuda.graph automatically sets a side stream as the current stream during capture.
report_memory("Before capture")
graphs = [None]  # index 0 is a placeholder; there is no graph for batch size 0
memory_pool = None
for i in range(1, MAX_BATCHSIZE + 1):
    with graph_capture(pool=memory_pool) as g:
        compute(i)
    graphs.append(g)
    memory_pool = g.pool()  # reuse the first graph's memory pool for subsequent captures
    report_memory(f"After capture batchsize {i}")
# Run the graph
static_a[:2] += 1
static_b[:2] += 2
graphs[2].replay()
torch.cuda.current_stream().synchronize()
print(static_output[:2])

Every few batch sizes, the memory usage grows a little:

Before capture: Used: 527.375 MB, Free: 80523.25 MB, Total: 81050.625 MB
After capture batchsize 1: Used: 529.375 MB, Free: 80521.25 MB, Total: 81050.625 MB
After capture batchsize 2: Used: 529.375 MB, Free: 80521.25 MB, Total: 81050.625 MB
After capture batchsize 3: Used: 529.375 MB, Free: 80521.25 MB, Total: 81050.625 MB
After capture batchsize 4: Used: 529.375 MB, Free: 80521.25 MB, Total: 81050.625 MB
After capture batchsize 5: Used: 529.375 MB, Free: 80521.25 MB, Total: 81050.625 MB
After capture batchsize 6: Used: 529.375 MB, Free: 80521.25 MB, Total: 81050.625 MB
After capture batchsize 7: Used: 529.375 MB, Free: 80521.25 MB, Total: 81050.625 MB
After capture batchsize 8: Used: 529.375 MB, Free: 80521.25 MB, Total: 81050.625 MB
After capture batchsize 9: Used: 529.375 MB, Free: 80521.25 MB, Total: 81050.625 MB
After capture batchsize 10: Used: 529.375 MB, Free: 80521.25 MB, Total: 81050.625 MB
After capture batchsize 11: Used: 531.375 MB, Free: 80519.25 MB, Total: 81050.625 MB
After capture batchsize 12: Used: 531.375 MB, Free: 80519.25 MB, Total: 81050.625 MB
After capture batchsize 13: Used: 533.375 MB, Free: 80517.25 MB, Total: 81050.625 MB
After capture batchsize 14: Used: 533.375 MB, Free: 80517.25 MB, Total: 81050.625 MB
After capture batchsize 15: Used: 533.375 MB, Free: 80517.25 MB, Total: 81050.625 MB
After capture batchsize 16: Used: 533.375 MB, Free: 80517.25 MB, Total: 81050.625 MB
After capture batchsize 17: Used: 533.375 MB, Free: 80517.25 MB, Total: 81050.625 MB

This would be expected behavior: segments allocated by the caching allocator cannot be reclaimed while the original graph capture is kept alive, and these segments have a fixed size unless something like expandable segments (currently not supported with graph capture, but we are working on it) is used. This increases the apparent memory usage, because the caching allocator has to allocate an entirely new segment every time a tensor does not fit, wasting the remaining space in previously allocated segments. A workaround you might try is capturing your largest batch size first, which increases the odds that tensors from subsequent captures fit in the already allocated segments.
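Something like this is what I mean, reusing the graph_capture/compute/report_memory helpers from your snippet (untested sketch):

graphs = [None] * (MAX_BATCHSIZE + 1)  # graphs[i] holds the graph for batch size i
memory_pool = None
# Capture from largest to smallest so that smaller captures can reuse
# the segments allocated for the largest capture.
for i in range(MAX_BATCHSIZE, 0, -1):
    with graph_capture(pool=memory_pool) as g:
        compute(i)
    graphs[i] = g
    memory_pool = g.pool()
    report_memory(f"After capture batchsize {i}")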


I tried capturing the batch sizes in reverse order (largest first), but the memory cost is the same.

I already use memory pool sharing, and I think the subsequent graphs should be able to safely reuse the allocated memory.

There will be some memory usage from the graph itself, so torch.cuda.memory_reserved() would give you a more accurate indication of whether this memory usage is actually coming from the tensors in the graph captures.
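For example, a variant of your report_memory helper along these lines (a sketch, not tested against your script) separates what the caching allocator has handed out to tensors from what it has reserved, and from what the rest of the process (CUDA context, graphs, etc.) is using:

import torch

def report_memory(prefix: str) -> None:
    # Bytes currently occupied by live tensors in the caching allocator.
    allocated = torch.cuda.memory_allocated()
    # Bytes held in allocator segments (allocated + cached, not returned to the driver).
    reserved = torch.cuda.memory_reserved()
    # Device-wide view: includes the CUDA context, graph bookkeeping, other processes, etc.
    free, total = torch.cuda.mem_get_info()
    print(
        f"{prefix}: allocated={allocated / 2**20:.1f} MB, "
        f"reserved={reserved / 2**20:.1f} MB, "
        f"device used={(total - free) / 2**20:.1f} MB"
    )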

Yes, memory can be reused, but segments cannot be resized, so going from smaller to larger batch sizes makes reuse more difficult. A segment cannot be "partially" reused if a tensor does not fit in it.
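To see whether a capture actually created a new segment rather than reusing an existing one, you can dump the allocator's segments. A rough sketch with torch.cuda.memory_snapshot() (the exact field names may differ slightly across PyTorch versions):

import torch

def report_segments(prefix: str) -> None:
    # memory_snapshot() returns one dict per caching-allocator segment.
    segments = torch.cuda.memory_snapshot()
    reserved = sum(seg["total_size"] for seg in segments)
    allocated = sum(seg["allocated_size"] for seg in segments)
    print(
        f"{prefix}: {len(segments)} segments, "
        f"{reserved / 2**20:.1f} MB reserved, {allocated / 2**20:.1f} MB allocated"
    )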

One thing to check is whether the same side stream is used across all captures: graph — PyTorch 2.3 documentation
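Concretely, you could pass one explicit side stream to every capture through the stream= argument that your graph_capture helper already forwards to torch.cuda.graph, e.g. (untested sketch):

capture_stream = torch.cuda.Stream()  # one side stream reused for every capture

graphs = [None]
memory_pool = None
for i in range(1, MAX_BATCHSIZE + 1):
    with graph_capture(pool=memory_pool, stream=capture_stream) as g:
        compute(i)
    graphs.append(g)
    memory_pool = g.pool()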

Can you explain this in more detail? A CUDA graph itself should not take GPU memory, if I understand correctly.

This is an implementation detail; see e.g. Why does CUDAGraph itself have memory consumption? - CUDA / CUDA Programming and Performance - NVIDIA Developer Forums