Unexpected Extra GPU Memory Usage in PyTorch (Beyond CUDA Context and Memory Pool)

Hi all,

Recently I’ve been trying to understand the details of GPU memory usage in PyTorch. I know that PyTorch maintains an internal memory pool (the caching allocator), so it’s expected that the reserved GPU memory is larger than the memory actually allocated for tensors. However, I’ve observed that even after accounting for the memory pool, the GPU memory usage reported by nvidia-smi is still much higher than expected.
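
For reference, the quantities being compared here can also be read from within a single process: torch.cuda.memory_allocated() (bytes backing live tensors), torch.cuda.memory_reserved() (the caching-allocator pool), and torch.cuda.mem_get_info() (the driver-level view, which also includes the CUDA context and loaded kernel modules). A minimal sketch:

import torch

x = torch.randn(2, 2, device="cuda")       # tiny tensor; the first CUDA call also creates the context

allocated = torch.cuda.memory_allocated()  # bytes currently backing live tensors
reserved = torch.cuda.memory_reserved()    # bytes held by the caching allocator (the pool)
free, total = torch.cuda.mem_get_info()    # driver-level view (cudaMemGetInfo)

print(f"allocated: {allocated} B")
print(f"reserved:  {reserved} B")
print(f"driver-reported used: {(total - free) / 2**20:.1f} MiB")  # includes context + loaded kernels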

Here’s what I did:

I created a very small tensor in PyTorch (shape = 2×2, dtype = float32) and moved it to the GPU. I recorded the GPU memory usage using nvidia-smi and torch.cuda.memory_summary() before and after the operation. Surprisingly, nvidia-smi showed that this small tensor transfer resulted in over 800 MB of GPU memory usage. That seems excessive. Meanwhile, torch.cuda.memory_summary() indicated that PyTorch itself had not allocated any significant amount of memory.

Here is my PyTorch code:

import torch
import subprocess

def get_gpu_memory():
    # query the used memory (in MiB) that nvidia-smi reports for each GPU
    result = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used',
                                     '--format=csv,nounits,noheader'])
    memory_per_gpu = [int(x) for x in result.decode('utf-8').strip().split('\n')]
    return memory_per_gpu[0]  # only GPU 0 is of interest here

# clean the cache before measuring
torch.cuda.empty_cache()
before_memory = get_gpu_memory()
print(torch.cuda.memory_summary(abbreviated=True))
subprocess.run(["nvidia-smi", "-i=0"], check=True)

# do the data transfer: create a small tensor on the CPU and copy it to the GPU
# (the result is not kept, so the allocation is freed again immediately)
with torch.no_grad():
    torch.randn(2, 2).cuda()

# measure gpu memory usage
after_memory = get_gpu_memory()
print(f"Memory increase after operations: {after_memory - before_memory} MB")

# clean cache and measure again
torch.cuda.empty_cache()
after_empty_cache = get_gpu_memory()
print(torch.cuda.memory_summary(abbreviated=True))
subprocess.run(["nvidia-smi", "-i=0"], check=True)
print(f"Memory after empty_cache: {after_empty_cache} MB")
print(f"Persistent overhead: {after_empty_cache - before_memory} MB")

And the output is

Python PyTorch output
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Requested memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| GPU reserved memory   |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Non-releasable memory |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Allocations           |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Active allocs         |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| GPU reserved segments |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

Mon Apr 14 10:59:55 2025       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:1A:00.0 Off |                  N/A |
| 30%   27C    P8    28W / 350W |      2MiB / 24576MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
Memory increase after operations: 828 MB
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |    512 B   |    512 B   |    512 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |    512 B   |    512 B   |    512 B   |
|---------------------------------------------------------------------------|
| Requested memory      |      0 B   |     16 B   |     16 B   |     16 B   |
|---------------------------------------------------------------------------|
| GPU reserved memory   |      0 B   |   2048 KiB |   2048 KiB |   2048 KiB |
|---------------------------------------------------------------------------|
| Non-releasable memory |      0 B   |   2047 KiB |   2047 KiB |   2047 KiB |
|---------------------------------------------------------------------------|
| Allocations           |       0    |       1    |       1    |       1    |
|---------------------------------------------------------------------------|
| Active allocs         |       0    |       1    |       1    |       1    |
|---------------------------------------------------------------------------|
| GPU reserved segments |       0    |       1    |       1    |       1    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       0    |       1    |       1    |       1    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

Mon Apr 14 10:59:57 2025       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:1A:00.0 Off |                  N/A |
| 30%   29C    P2   101W / 350W |    826MiB / 24576MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A    430253      C   python                            824MiB |
+-----------------------------------------------------------------------------+
Memory after empty_cache: 826 MB
Persistent overhead: 826 MB

I suspected that part of the memory usage might come from CUDA context initialization. To test this, I wrote a minimal C++ program using the CUDA Driver API to perform the same task. The C++ version consumed around 255 MB of GPU memory, which I assume corresponds to the CUDA context overhead.

Here is my CUDA C++ code

#include <cuda_runtime.h>
#include <cuda.h>
#include <iostream>
#include <vector>
#include <string>
#include <iomanip>

// print GPU 0 memory usage
class GPUMemoryMonitor {
public:
    GPUMemoryMonitor() {
        cudaGetDeviceCount(&deviceCount);
    }

    void printAllGPUMemoryUsage(const std::string& label = "") {
        if (!label.empty()) {
            std::cout << "===== " << label << " =====" << std::endl;
        }

        for (int i = 0; i < 1; i++) {
            printGPUMemoryUsage(i);
        }

        if (!label.empty()) {
            std::cout << std::string(label.length() + 12, '=') << std::endl;
        }
    }

    void printGPUMemoryUsage(int deviceId) {
        size_t free, total;
        cudaSetDevice(deviceId);
        cudaMemGetInfo(&free, &total);

        double usedMB = (total - free) / 1024.0 / 1024.0;
        double totalMB = total / 1024.0 / 1024.0;
        double percentUsed = 100.0 * (total - free) / total;

        std::cout << "GPU " << deviceId << ": "
                  << std::fixed << std::setprecision(2) << usedMB << " MB / "
                  << totalMB << " MB ("
                  << percentUsed << "%) used" << std::endl;
    }

private:
    int deviceCount;
};

int main() {
    GPUMemoryMonitor memMonitor;
    memMonitor.printAllGPUMemoryUsage("Initial GPU Memory State");

    cudaError_t cudaStatus;
    int deviceId = 0;

    int deviceCount;
    cudaStatus = cudaGetDeviceCount(&deviceCount);
    if (cudaStatus != cudaSuccess || deviceCount == 0) {
        std::cerr << "No CUDA devices found: " << cudaGetErrorString(cudaStatus) << std::endl;
        return 1;
    }

    cudaSetDevice(deviceId);

    cudaStream_t stream;
    cudaStatus = cudaStreamCreate(&stream);
    if (cudaStatus != cudaSuccess) {
        std::cerr << "cudaStreamCreate failed: " << cudaGetErrorString(cudaStatus) << std::endl;
        return 1;
    }

    memMonitor.printAllGPUMemoryUsage("After Stream Creation");

    const size_t numElements = 2 * 2;
    const size_t sizeInBytes = numElements * sizeof(float);

    std::vector<float> h_data(numElements);

    float* d_data = nullptr;
    cudaStatus = cudaMalloc((void**)&d_data, sizeInBytes);
    if (cudaStatus != cudaSuccess) {
        std::cerr << "cudaMalloc failed: " << cudaGetErrorString(cudaStatus) << std::endl;
        cudaStreamDestroy(stream);
        return 1;
    }

    memMonitor.printAllGPUMemoryUsage("After Device Memory Allocation");

    cudaStatus = cudaMemcpyAsync(d_data, h_data.data(), sizeInBytes, cudaMemcpyHostToDevice, stream);
    if (cudaStatus != cudaSuccess) {
        std::cerr << "cudaMemcpyAsync failed: " << cudaGetErrorString(cudaStatus) << std::endl;
        cudaFree(d_data);
        cudaStreamDestroy(stream);
        return 1;
    }

    memMonitor.printAllGPUMemoryUsage("After cudaMemcpyAsync");

    cudaStatus = cudaStreamSynchronize(stream);
    if (cudaStatus != cudaSuccess) {
        std::cerr << "cudaStreamSynchronize failed: " << cudaGetErrorString(cudaStatus) << std::endl;
        cudaFree(d_data);
        cudaStreamDestroy(stream);
        return 1;
    }

    memMonitor.printAllGPUMemoryUsage("After Stream Synchronization");

    cudaFree(d_data);
    memMonitor.printAllGPUMemoryUsage("After cudaFree");

    cudaStreamDestroy(stream);
    memMonitor.printAllGPUMemoryUsage("After Stream Destroy");

    return 0;
}

And the output is

C++ CUDA output
===== Initial GPU Memory State =====
GPU 0: 254.75 MB / 24268.31 MB (1.05%) used
====================================
===== After Stream Creation =====
GPU 0: 254.75 MB / 24268.31 MB (1.05%) used
=================================
===== After Device Memory Allocation =====
GPU 0: 258.75 MB / 24268.31 MB (1.07%) used
==========================================
===== After cudaMemcpyAsync =====
GPU 0: 258.75 MB / 24268.31 MB (1.07%) used
=================================
===== After Stream Synchronization =====
GPU 0: 258.75 MB / 24268.31 MB (1.07%) used
========================================
===== After cudaFree =====
GPU 0: 254.75 MB / 24268.31 MB (1.05%) used
==========================
===== After Stream Destroy =====
GPU 0: 254.75 MB / 24268.31 MB (1.05%) used
================================

So, after subtracting the CUDA context overhead (~255 MB, according to the C++ program), PyTorch still seems to be using about 600 MB of additional GPU memory that I cannot account for.

Here are my questions:

  • Is this additional GPU memory usage by PyTorch expected or normal?
  • Is there any way to trace or profile the source of GPU memory allocations? I tried the NVIDIA Nsight suite but couldn’t get detailed memory allocation information (an allocator-level alternative is sketched right after this list).
  • Is there any documentation or explanation for this extra ~600 MB of memory usage? I haven’t been able to find anything.
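
For completeness, the only allocator-level tracing I am aware of is PyTorch’s built-in memory-history snapshot (a minimal sketch below); since it only records allocations routed through the caching allocator, it does not seem to cover the extra usage that nvidia-smi reports:

import torch

# record caching-allocator events (allocations, frees, segment mapping)
torch.cuda.memory._record_memory_history()

with torch.no_grad():
    torch.randn(2, 2).cuda()

# dump a snapshot that can be inspected at https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("cuda_mem_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording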

Any help or insights would be greatly appreciated. Thank you in advance!

I see a memory usage of ~260MB, which corresponds to the CUDA context, the lazily loaded CUDA kernels, and the allocations performed by PyTorch.

I cannot reproduce the reported 600MB increase and don’t know why you would see it, unless you are using a really old PyTorch release built with a CUDA version before 11.7 or are explicitly disabling CUDA’s lazy module loading.
In all PyTorch binaries built with CUDA 11.7+ we enabled lazy module loading, which loads only the needed kernels into the context instead of all of them (loading everything eagerly was creating a large overhead).
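
For reference, this lazy loading behavior is controlled by CUDA’s CUDA_MODULE_LOADING environment variable (LAZY vs. EAGER), which must be set before the CUDA context is created; PyTorch’s CUDA 11.7+ binaries default it to LAZY if it is unset. A rough sketch to compare the driver-reported footprint under the two modes (run it once per mode, e.g. CUDA_MODULE_LOADING=EAGER python check.py, where check.py is just a placeholder name for the script):

import os
import torch

# the first CUDA call creates the context; mem_get_info reports the driver-level view
free_before, total = torch.cuda.mem_get_info(0)
torch.randn(2, 2, device="cuda")   # forces at least one kernel module to be loaded
torch.cuda.synchronize()
free_after, _ = torch.cuda.mem_get_info(0)

print("CUDA_MODULE_LOADING =", os.environ.get("CUDA_MODULE_LOADING", "<unset>"))
print(f"used after context init: {(total - free_before) / 2**20:.1f} MiB")
print(f"used after a tiny op:    {(total - free_after) / 2**20:.1f} MiB")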

Hi @ptrblck, thank you for trying to reproduce this issue! I’m glad to hear you confirmed that roughly 260MB of GPU memory usage is expected from the CUDA context and lazily loaded kernels.

Regarding the additional ~600 MB of GPU memory consumption that you couldn’t reproduce, I can provide more details. My experimental environment is as follows:

  • PyTorch version and the CUDA version it was built with: 2.1.0+cu118 / 11.8 (python -c "import torch; print(torch.__version__, torch.version.cuda)")
  • Operating system: Rocky Linux 8.7 (cat /etc/os-release)
  • GPU driver version: 520.61.05 (nvidia-smi)
  • Driver-reported CUDA version: 11.8 (nvidia-smi)
  • GPU model: 8 NVIDIA GeForce RTX 3090 GPUs attached to the same server

Additionally, I don’t believe any extra kernels are being loaded. I recorded the CUDA API call sequences for both the Python and the C++ code using the Nsight Systems CLI, and throughout the entire program execution I only observed the following CUDA API calls:

cudaMemcpyAsync
cudaFree
cudaMalloc
cuGetProcAddress
cudaStreamSynchronize
cudaStreamIsCapturing_v10000
cuModuleGetLoadingMode
cuInit

At a glance, there doesn’t seem to be anything particularly suspicious. Is there any other information I haven’t mentioned that might be relevant? Thank you again.

Are you seeing the same additional memory usage in the latest nightly release?

Hi @ptrblck, thank you for your suggestion! I tried the latest PyTorch release (2.6.0), and the issue still persists. However, your reminder prompted me to run the test on a server with a newer NVIDIA driver version (550.90.07), and there I saw only ~260 MB of VRAM usage. So this issue seems to be related to subtle details of the NVIDIA driver rather than to PyTorch itself. Anyway, thank you for your response!