GPU memory reservation

Is there any way to make PyTorch reserve less GPU memory? I found that it reserves GPU memory very aggressively, even for simple computations, which causes CUDA OOM errors for large ones. Here is a code snippet:

In [1]: import torch as tc
In [2]: m = tc.nn.Sequential(
   ...:     tc.nn.Linear(1000, 1000),
   ...:     tc.nn.Linear(1000, 1000),
   ...:     tc.nn.Linear(1000, 1000),
   ...:     tc.nn.Linear(1000, 1),
   ...: ).cuda()
In [3]: x = tc.randn(1000, 1000).cuda()
In [4]: m(x).sum().backward()  # a forward pass followed by a backward pass

After this, nvidia-smi reports 1285 MiB of GPU memory in use:

|    0   N/A  N/A     31758      C   ...-3.7.6/bin/python3.7     1285MiB |

But only a few tensors are allocated on the GPU:

In [10]: import gc

In [11]: [obj.shape for obj in gc.get_objects() if tc.is_tensor(obj) and obj.device.type != 'cpu']
Out[11]:
[torch.Size([1000, 1000]),
 torch.Size([1000, 1000]),
 torch.Size([1000]),
 torch.Size([1000, 1000]),
 torch.Size([1000]),
 torch.Size([1000, 1000]),
 torch.Size([1000]),
 torch.Size([1, 1000]),
 torch.Size([1])]

I don't want to call torch.cuda.empty_cache(), since it is very expensive inside a training loop.
Please advise, thank you!

Here is the summary printed by PyTorch itself, which shows a much smaller footprint. However, OOM is triggered as soon as the usage reported by nvidia-smi hits the limit:

In [15]: print(tc.cuda.memory_summary())
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   27378 KB |   36136 KB |   52001 KB |   24622 KB |
|       from large pool |   27345 KB |   36106 KB |   51732 KB |   24386 KB |
|       from small pool |      33 KB |      49 KB |     269 KB |     236 KB |
|---------------------------------------------------------------------------|
| Active memory         |   27378 KB |   36136 KB |   52001 KB |   24622 KB |
|       from large pool |   27345 KB |   36106 KB |   51732 KB |   24386 KB |
|       from small pool |      33 KB |      49 KB |     269 KB |     236 KB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   43008 KB |   43008 KB |   43008 KB |       0 B  |
|       from large pool |   40960 KB |   40960 KB |   40960 KB |       0 B  |
|       from small pool |    2048 KB |    2048 KB |    2048 KB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |   15629 KB |   18617 KB |   59813 KB |   44184 KB |
|       from large pool |   13614 KB |   16573 KB |   57533 KB |   43919 KB |
|       from small pool |    2015 KB |    2044 KB |    2280 KB |     265 KB |
|---------------------------------------------------------------------------|
| Allocations           |      17    |      24    |     241    |     224    |
|       from large pool |       7    |       9    |      13    |       6    |
|       from small pool |      10    |      17    |     228    |     218    |
|---------------------------------------------------------------------------|
| Active allocs         |      17    |      24    |     241    |     224    |
|       from large pool |       7    |       9    |      13    |       6    |
|       from small pool |      10    |      17    |     228    |     218    |
|---------------------------------------------------------------------------|
| GPU reserved segments |       3    |       3    |       3    |       0    |
|       from large pool |       2    |       2    |       2    |       0    |
|       from small pool |       1    |       1    |       1    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       6    |       8    |     105    |      99    |
|       from large pool |       3    |       3    |       8    |       5    |
|       from small pool |       3    |       5    |      97    |      94    |
|===========================================================================|

The posted code would allocate ~27MB, which is also shown in your memory_summary() output:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(1000, 1000),
    nn.Linear(1000, 1000),
    nn.Linear(1000, 1000),
    nn.Linear(1000, 1)
)

# nothing lives on the GPU yet
print(torch.cuda.memory_allocated())
# > 0

# parameters moved to the GPU
model.cuda()
print(torch.cuda.memory_allocated()/1024**2)
# > 11.4609375

# plus the input tensor
x = torch.randn(1000, 1000, device='cuda')
print(torch.cuda.memory_allocated()/1024**2)
# > 15.27587890625

# plus the activations kept for the backward pass
out = model(x)
print(torch.cuda.memory_allocated()/1024**2)
# > 27.64990234375

# activations released, parameter gradients allocated
out.sum().backward()
print(torch.cuda.memory_allocated()/1024**2)
# > 26.74072265625

The peak usage is also reasonable.

The rest of the memory usage is caused by the CUDA context, which loads all native PyTorch CUDA kernels as well as CUDA kernels from libraries such as cuDNN, cuBLAS, etc.
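The figures above can be checked by hand. The sketch below is plain arithmetic, not a PyTorch API call: it assumes float32 tensors (4 bytes per element) and assumes that the caching allocator rounds every request up to a multiple of 512 bytes. The helper name `rounded` is made up for this illustration. Under those assumptions it reproduces the two values printed after `model.cuda()` and after creating `x`, and it also computes how much of the 1285 MiB reported by nvidia-smi lies outside the 43008 KB that PyTorch reserved.

```python
from math import prod

def rounded(nbytes, block=512):
    # Assumption: the allocator rounds each request up to a multiple of 512 bytes.
    return -(-nbytes // block) * block

BYTES_PER_FLOAT32 = 4

# Weight and bias shapes of the four Linear layers in the posted model.
shapes = [(1000, 1000), (1000,)] * 3 + [(1, 1000), (1,)]
param_bytes = sum(rounded(BYTES_PER_FLOAT32 * prod(shape)) for shape in shapes)
print(param_bytes / 1024**2)  # 11.4609375

# Adding the 1000 x 1000 input tensor.
x_bytes = rounded(BYTES_PER_FLOAT32 * 1000 * 1000)
print((param_bytes + x_bytes) / 1024**2)  # 15.27587890625

# nvidia-smi total minus PyTorch's reserved memory (43008 KB), in MiB.
outside_allocator_mib = 1285 - 43008 / 1024
print(outside_allocator_mib)  # 1243.0
```

The roughly 1.2 GiB left over is the part that is attributed above to the CUDA context and the loaded kernels rather than to tensors.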

Thank you for your reply!
Are you suggesting that the rest of the allocation comes from the CUDA libraries (which should then take constant space)?
That doesn't seem to be the case when I increase the input size:

In [1]: import torch as tc
In [2]: m = tc.nn.Sequential(
   ...:     tc.nn.Linear(1000, 1000),
   ...:     tc.nn.Linear(1000, 1000),
   ...:     tc.nn.Linear(1000, 1000),
   ...:     tc.nn.Linear(1000, 1),
   ...: ).cuda()
In [3]: x = tc.randn(100000, 1000).cuda()  # 400MB

In [7]: m(x).sum().backward()

In [8]: tc.cuda.memory_allocated()  # 1600MB, this number makes sense to me
Out[8]: 1627630080

In [9]: tc.cuda.memory_reserved()  # why double the above size?
Out[9]: 3338665984

And after this call, nvidia-smi shows an even bigger memory footprint:

|===============================+======================+======================|
|   0  A100-SXM4-40GB      Off  | 00000000:10:1C.0 Off |                    0 |
| N/A   40C    P0    70W / 350W |   4430MiB / 40537MiB |      0%      Default |
|                               |                      |             Disabled |

IIUC, 4 CUDA tensors of 400MB each should be created (one for the input, 3 for the intermediate results after each linear layer, kept for the gradient computation), and the rest should be negligible. So the total CUDA memory footprint should be about 1.6GB. But tc.cuda.memory_reserved() reports roughly twice as much, and nvidia-smi reports even more.
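The estimate can be written out as plain arithmetic. This sketch only uses the numbers quoted in this post (the memory_allocated() and memory_reserved() outputs and the 4430 MiB from nvidia-smi); the count of four 400MB tensors is the assumption stated above, not something measured.

```python
BYTES_PER_FLOAT32 = 4

# One 100000 x 1000 float32 tensor.
tensor_bytes = 100_000 * 1_000 * BYTES_PER_FLOAT32
print(tensor_bytes)  # 400000000

# Assumption from the post: input plus three intermediate results.
expected_bytes = 4 * tensor_bytes
print(expected_bytes)  # 1600000000

allocated = 1_627_630_080   # tc.cuda.memory_allocated()
reserved = 3_338_665_984    # tc.cuda.memory_reserved()
print(round(reserved / allocated, 2))  # 2.05

# Part of the nvidia-smi figure that is not PyTorch's reserved pool, in MiB.
reserved_mib = reserved / 1024**2
print(reserved_mib)         # 3184.0
print(4430 - reserved_mib)  # 1246.0
```

So the allocated figure matches the estimate to within about 2%, and the open question is the gap between allocated and reserved, plus the roughly 1.2 GiB that nvidia-smi reports on top of the reserved memory.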

Any idea?

thanks!

memory_reserved shows the allocated and cached memory. You could use the memory_summary again to check the different allocations.
nvidia-smi shows the overall memory usage. If no other processes are running, please share an executable code snippet to reproduce the memory increase seen via nvidia-smi.
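To see the three quantities side by side, something like the following sketch could be used. The function name `gpu_memory_report` is made up for this example, and it assumes a PyTorch version that provides torch.cuda.mem_get_info(), which returns the free and total device memory as seen by the driver. The device-wide figure includes the CUDA context and any other process on the GPU, so it is only comparable to nvidia-smi when nothing else is running.

```python
import torch

def gpu_memory_report(device=0):
    """Return byte counts for one GPU, or None if CUDA is unavailable."""
    if not torch.cuda.is_available():
        return None
    # Touch the device so that the CUDA context exists before measuring.
    torch.zeros(1, device=f"cuda:{device}")
    allocated = torch.cuda.memory_allocated(device)  # live tensors
    reserved = torch.cuda.memory_reserved(device)    # live tensors + cache
    free, total = torch.cuda.mem_get_info(device)    # driver's view of the device
    return {
        "allocated": allocated,
        "cached": reserved - allocated,
        "reserved": reserved,
        "device_used": total - free,
    }

report = gpu_memory_report()
if report is None:
    print("CUDA is not available")
else:
    for name, nbytes in report.items():
        print(f"{name:>12}: {nbytes / 1024**2:10.1f} MiB")
```

The difference between `device_used` and `reserved` is memory that PyTorch's caching allocator does not manage, such as the CUDA context.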

Do you mean something like this? After running this code, the CUDA memory footprint reported by nvidia-smi is 3.2GB:

import torch as tc

m = tc.nn.Sequential(
    tc.nn.Linear(1000, 1000),
    tc.nn.Linear(1000, 1000),
    tc.nn.Linear(1000, 1000),
    tc.nn.Linear(1000, 1),
).cuda()

x = tc.randn(100000, 1000).cuda()  # 400MB

m(x).sum().backward()

+-------------------------------------------------------------------------+
| Processes:                                                              |
|  GPU   GI   CI        PID   Type   Process name              GPU Memory |
|        ID   ID                                               Usage      |
|=========================================================================|
|    0   N/A  N/A     79663      C   ...n-3.7.6/bin/python3.7     3281MiB |
+-------------------------------------------------------------------------+

Thanks!

As you can see in the memory_summary(), PyTorch reserves ~2GB, so given the model size, the CUDA context, and the PyTorch cache, this memory usage is expected:

| GPU reserved memory   |    2038 MB |    2038 MB |    2038 MB |       0 B  |
|       from large pool |    2036 MB |    2036 MB |    2036 MB |       0 B  |
|       from small pool |       2 MB |       2 MB |       2 MB |       0 B  |

If you want to release the cache, use torch.cuda.empty_cache(). This synchronizes your code and therefore slows it down, but it allows other applications to use the memory.
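A minimal sketch of the effect, assuming a CUDA device is present. The function name `reserved_before_and_after` and the tensor size are arbitrary choices for this illustration. Deleting a tensor returns its block to PyTorch's cache rather than to the driver, so the reserved figure should only shrink once empty_cache() is called.

```python
import torch

def reserved_before_and_after(n=1024):
    """Return reserved bytes before and after empty_cache(), or None without CUDA."""
    if not torch.cuda.is_available():
        return None
    x = torch.empty(n, n, device="cuda")  # n * n float32 values
    del x  # the block goes back to PyTorch's cache, not to the driver
    before = torch.cuda.memory_reserved()
    torch.cuda.empty_cache()  # hand unused cached blocks back to the driver
    after = torch.cuda.memory_reserved()
    return before, after

result = reserved_before_and_after()
if result is None:
    print("CUDA is not available")
else:
    before, after = result
    print(f"reserved before: {before / 1024**2:.1f} MiB")
    print(f"reserved after:  {after / 1024**2:.1f} MiB")
```

Memory that is still referenced by live tensors is not affected, and the CUDA context itself stays loaded, so the nvidia-smi figure will not drop to zero.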
