GPU memory leak in computational graph, not tensor?

Hi everyone,

After a truly horrifying debugging session today trying to find a memory leak, I discovered that in some circumstances PyTorch's computational graphs (?) can hold on to memory even when all the corresponding tensors are long gone. At the end of my example below no tensors remain, yet PyTorch still has 1 GB of CUDA memory allocated.

What exactly is taking up this memory? How can I find the objects responsible, and how can I release the memory that remains allocated?
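One way to look for the objects holding this memory is to walk the autograd graph that a tensor's `grad_fn` still reaches, following each node's `next_functions`. Below is a minimal sketch that runs on CPU; `collect_graph_nodes` is a hypothetical helper name, and it only lists the graph nodes, not the tensors each node has saved. It walks iteratively so that a long chain of nodes (such as the one the per-row in-place copies may build) cannot hit Python's recursion limit.

```python
import torch


def collect_graph_nodes(tensor):
    """Return the autograd nodes reachable from tensor.grad_fn (depth-first).

    Hypothetical helper for inspection only; a leaf tensor has no grad_fn,
    so an empty list is returned for it.
    """
    root = tensor.grad_fn
    if root is None:
        return []
    seen = set()
    order = []
    stack = [root]
    while stack:
        node = stack.pop()
        if node in seen:
            continue  # shared sub-graphs (e.g. one AccumulateGrad) are visited once
        seen.add(node)
        order.append(node)
        # next_functions is a tuple of (node, input_index); node is None for
        # inputs that do not require grad.
        for next_node, _ in node.next_functions:
            if next_node is not None:
                stack.append(next_node)
    return order


if __name__ == "__main__":
    a = torch.randn(4, 4, requires_grad=True)
    y = a.relu() ** 1.0 - (-a).relu() ** 1.0
    for node in collect_graph_nodes(y):
        print(type(node).__name__)
```

In the example from this post, calling it on `x` after the loop should list the nodes that were created for `y`, which is one way to check whether `x.grad_fn` still reaches them.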

import torch
import gc

x = torch.randn(2**10, 2**16, device='cuda', requires_grad=True)

print('Before')
print(f"\tAllocated memory: {torch.cuda.memory.memory_allocated()/2**20}MB")
print(f"\t'Size of x': {x.element_size() * x.nelement() / 2**20}MB")
print(torch.cuda.memory_summary())

# Change x's `grad_fn` only!
y = x.relu()**1.0 - (-x).relu()**1.0
for i in range(len(x)):
    x[i] = y[i]
del y

print('After changing x and deleting y:')
## Print all allocated tensors (There's still only 1, which is x)
n = 0
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            n += 1
            print(f"\t{n}:{type(obj)}, Size: {obj.size()}")
    except Exception:  # some objects raise on attribute access; skip them
        pass

print(f"\tAllocated memory: {torch.cuda.memory.memory_allocated()/2**20}MB")
print(f"\t'Size of x': {x.element_size() * x.nelement() / 2**20}MB")

# now delete x
del x

print('After Deleting x:')
print(torch.cuda.memory_summary())

Output

Before
	Allocated memory: 256.0MB
	'Size of x': 256.0MB
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  262144 KB |  262144 KB |  262144 KB |       0 B  |
|       from large pool |  262144 KB |  262144 KB |  262144 KB |       0 B  |
|       from small pool |       0 KB |       0 KB |       0 KB |       0 B  |
|---------------------------------------------------------------------------|
| Active memory         |  262144 KB |  262144 KB |  262144 KB |       0 B  |
|       from large pool |  262144 KB |  262144 KB |  262144 KB |       0 B  |
|       from small pool |       0 KB |       0 KB |       0 KB |       0 B  |
|---------------------------------------------------------------------------|
| GPU reserved memory   |  262144 KB |  262144 KB |  262144 KB |       0 B  |
|       from large pool |  262144 KB |  262144 KB |  262144 KB |       0 B  |
|       from small pool |       0 KB |       0 KB |       0 KB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Allocations           |       1    |       1    |       1    |       0    |
|       from large pool |       1    |       1    |       1    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Active allocs         |       1    |       1    |       1    |       0    |
|       from large pool |       1    |       1    |       1    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| GPU reserved segments |       1    |       1    |       1    |       0    |
|       from large pool |       1    |       1    |       1    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|===========================================================================|

After changing x and deleting y:
	1:<class 'torch.Tensor'>, Size: torch.Size([1024, 65536])
	Allocated memory: 1024.0MB
	'Size of x': 256.0MB
After Deleting x:
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |    1024 MB |    1792 MB |    1792 MB |     768 MB |
|       from large pool |    1024 MB |    1792 MB |    1792 MB |     768 MB |
|       from small pool |       0 MB |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| Active memory         |    1024 MB |    1792 MB |    1792 MB |     768 MB |
|       from large pool |    1024 MB |    1792 MB |    1792 MB |     768 MB |
|       from small pool |       0 MB |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |    1792 MB |    1792 MB |    1792 MB |       0 B  |
|       from large pool |    1792 MB |    1792 MB |    1792 MB |       0 B  |
|       from small pool |       0 MB |       0 MB |       0 MB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Allocations           |       4    |       7    |       7    |       3    |
|       from large pool |       4    |       7    |       7    |       3    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Active allocs         |       4    |       7    |       7    |       3    |
|       from large pool |       4    |       7    |       7    |       3    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| GPU reserved segments |       7    |       7    |       7    |       0    |
|       from large pool |       7    |       7    |       7    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|===========================================================================|
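To see which tensors the graph keeps alive when `y` is built, one option is to intercept everything autograd saves for backward with `torch.autograd.graph.saved_tensors_hooks` (available in recent PyTorch versions). The sketch below is an illustration under that assumption, not a diagnosis of this specific leak: `saved_for_backward` is a hypothetical helper, it deduplicates by `data_ptr()` so a tensor saved by two nodes is counted once, and it runs on CPU so it can be tried without a GPU.

```python
import torch


def saved_for_backward(fn, *args):
    """Run fn(*args) and report what autograd saved for the backward pass.

    Hypothetical helper. Returns (output, number_of_distinct_saved_tensors,
    total_bytes_of_those_tensors).
    """
    sizes = {}  # data_ptr -> bytes, so a tensor saved by several nodes counts once

    def pack(t):
        sizes[t.data_ptr()] = t.element_size() * t.nelement()
        return t  # store the tensor itself, i.e. keep autograd's default behaviour

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        out = fn(*args)
    return out, len(sizes), sum(sizes.values())


if __name__ == "__main__":
    a = torch.randn(8, 16, requires_grad=True)
    out, n, nbytes = saved_for_backward(
        lambda t: t.relu() ** 1.0 - (-t).relu() ** 1.0, a
    )
    print(f"{n} saved tensors, {nbytes} bytes")
```

With the post's `x` on CUDA, the same call would report how much of the extra allocation consists of tensors saved by the graph rather than tensors referenced from Python.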