Capture CUDAGraph without using double the memory

Hi!

I can successfully capture a CUDAGraph and replay it. I took the API example from this blog and modified it for my own model. Normally I can run forward and backward passes with a batch size of 25, but when I capture the graph, the reserved memory suddenly doubles, so I can only use a batch size of 8.

...
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

train_x, train_y = data_generator(T, seq_len, n_train)
model = TCN(1, n_classes, channel_sizes, kernel_size, dropout=dropout).to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print('Before stream s, allocated: ', torch.cuda.memory_allocated('cuda'), '; reserved: ', torch.cuda.memory_reserved('cuda'))
# warmup
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        inputs = train_x[batch_size*i:batch_size*(i+1)].unsqueeze(1).contiguous().to(device)
        labels = train_y[batch_size*i:batch_size*(i+1)].to(device)
        optimizer.zero_grad(set_to_none=True)
        y_pred = model(inputs)
        loss = criterion(y_pred.view(-1, n_classes), labels.view(-1))
        loss.backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

print('After stream s before graph g, allocated: ', torch.cuda.memory_allocated('cuda'), '; reserved: ', torch.cuda.memory_reserved('cuda'))
print('Start capturing CUDA graph.')

# capture
g = torch.cuda.CUDAGraph()
# Sets grads to None before capture, so backward() will create
# .grad attributes with allocations from the graph's private pool
optimizer.zero_grad(set_to_none=True)
static_inputs = train_x[:batch_size].unsqueeze(1).contiguous().to(device)
static_labels = train_y[:batch_size].to(device)
with torch.cuda.graph(g):
    static_y_pred = model(static_inputs)
    static_loss = criterion(static_y_pred.view(-1, n_classes), static_labels.view(-1))
    static_loss.backward()
    optimizer.step()
    
print('After graph g before replay, allocated: ', torch.cuda.memory_allocated('cuda'), '; reserved: ', torch.cuda.memory_reserved('cuda'))

for i in range(30):
    inputs = train_x[batch_size*i:batch_size*(i+1)].unsqueeze(1).contiguous().to(device)
    labels = train_y[batch_size*i:batch_size*(i+1)].to(device)
    static_inputs.copy_(inputs)
    static_labels.copy_(labels)
    g.replay()

print('After replay, allocated: ', torch.cuda.memory_allocated('cuda'), '; reserved: ', torch.cuda.memory_reserved('cuda'))

The output is:

Before stream s, allocated:  2978816 ; reserved:  4194304
After stream s before graph g, allocated:  13397522944 ; reserved:  15101591552
Start capturing CUDA graph.
After graph g before replay, allocated:  13400187392 ; reserved:  30045896704
After replay, allocated:  13400187392 ; reserved:  30045896704

If I try a larger batch size (9, 25, etc.), everything works up until the graph capture, which then fails with an out-of-memory error. The error output is:

Traceback (most recent call last):
  File "/gpfs/fs1/home/minzhao.liu/TCN/cuda_graph.py", line 123, in <module>
    optimizer.step()
  File "/home/minzhao.liu/.conda/envs/qtensor-torch/lib/python3.9/site-packages/torch/cuda/graphs.py", line 149, in __exit__
    self.cuda_graph.capture_end()
  File "/home/minzhao.liu/.conda/envs/qtensor-torch/lib/python3.9/site-packages/torch/cuda/graphs.py", line 71, in capture_end
    super(CUDAGraph, self).capture_end()
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

Also, can I save the graph and run it on a different machine? If anyone can recommend some resources and more examples of CUDA graphs with PyTorch, I would really appreciate it, since I have only found a few minimal examples.

What jumps out at me is

Before stream s, allocated:  2978816...
After stream s before graph g, allocated:  13397522944

The model evidently isn’t that big (about 3 MB), but somehow after warmup the allocated memory shoots up to 13 GB.
I’m not surprised the reserved memory gets pushed up to ~15 GB; it grows to accommodate the high-water mark reached while the forward pass builds the autograd graph and its saved activations. But after backward(), that autograd graph should be torn down, and the allocated memory should settle back to whatever the model params and optimizer states need, which is only a few times the size of the parameter set, in other words 3 MB x N where N is a small integer.
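As a rough sanity check (a hypothetical snippet, not part of your script), you can estimate that steady-state footprint straight from the parameter count; for SGD with momentum it is roughly params + grads + momentum buffers, i.e. about 3x the parameter bytes:

param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
# params + .grad attributes + SGD momentum buffers, roughly
expected_steady_state = 3 * param_bytes
print('param bytes:', param_bytes, '; expected steady state: ~', expected_steady_state)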

Can you add

gc.collect()
torch.cuda.empty_cache()

just above

print('After stream s before graph g, allocated: ', torch.cuda.memory_allocated('cuda'), '; reserved: ', torch.cuda.memory_reserved('cuda'))

and report the values?
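For context, that placement would look like this (placement only, variable names taken from your snippet):

import gc

torch.cuda.current_stream().wait_stream(s)

gc.collect()                # drop unreferenced Python objects that still hold tensors
torch.cuda.empty_cache()    # return cached blocks the allocator no longer needs

print('After stream s before graph g, allocated: ', torch.cuda.memory_allocated('cuda'), '; reserved: ', torch.cuda.memory_reserved('cuda'))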

can I save the graph and run it on a different machine

No, there’s currently no way to serialize and deserialize graphs.

more examples of CUDA graphs with PyTorch

The master docs are the best resource.
https://pytorch.org/docs/master/notes/cuda.html#cuda-graphs
Some internal people are adding cuda graphs to https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch, but I don’t know how far that effort has progressed.
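Those notes also cover torch.cuda.make_graphed_callables, which hides the warmup and capture boilerplate behind a regular callable. A minimal sketch (the module and shapes here are placeholders, not your TCN):

import torch

model = torch.nn.Linear(64, 10).cuda()
sample_input = torch.randn(8, 64, device='cuda')

# Runs its own warmup iterations on a side stream, then captures the module's
# forward and backward into CUDA graphs.
graphed_model = torch.cuda.make_graphed_callables(model, (sample_input,))

out = graphed_model(torch.randn(8, 64, device='cuda'))  # copies into static inputs, replays
out.sum().backward()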

Thank you so much for the suggestions! After adding your two lines of code, the allocated memory is still:

After stream s before graph g, allocated:  13397522944 ; reserved:  14986248192.

I think I need to figure out why the memory is not clearing. Is this the reason why the reserved memory doubles during graph capture?

The model is a quantum simulation model that generates enormous tensors and uses torch.einsum everywhere with complicated Python logic, which is why it gets so large. The simulator itself is a black box to me, so I’ll have to dig deeper.

I solved the problem. The simulator was keeping tensors alive in Python lists and dictionaries. After deleting those lists with del, the memory was released after the backward passes.
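For anyone hitting the same thing, a minimal illustration of the pattern (hypothetical names, not the actual simulator code):

import torch

x = torch.randn(256, 256, device='cuda', requires_grad=True)

# Intermediates stashed in a Python container stay alive, together with the
# autograd state that references them, even after backward() finishes.
intermediates = []
h = x
for _ in range(8):
    h = torch.einsum('ij,jk->ik', h, h) / h.shape[0]
    intermediates.append(h)

intermediates[-1].sum().backward()

# Deleting the container lets the caching allocator reuse those blocks, so
# allocated memory drops back down before the graph capture starts.
del intermediates, h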