Hi all,
Recently I’ve been trying to understand the details of GPU memory usage in PyTorch. I know that PyTorch maintains an internal memory pool, so it’s expected that the reserved GPU memory is larger than the memory actually allocated. However, I’ve observed that even after accounting for the memory pool, the GPU memory usage reported by nvidia-smi is still much higher than expected.
Here’s what I did:
I created a very small tensor in PyTorch (shape = 2×2, dtype = float32) and moved it to the GPU, recording the GPU memory usage with nvidia-smi and torch.cuda.memory_summary() before and after the operation. Surprisingly, nvidia-smi showed that this small tensor transfer resulted in over 800 MB of GPU memory usage, which seems excessive. Meanwhile, torch.cuda.memory_summary() indicated that PyTorch itself had not allocated any significant amount of memory.
Here is my PyTorch code:
import torch
import subprocess

def get_gpu_memory():
    result = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used',
                                      '--format=csv,nounits,noheader'])
    memory_per_gpu = [int(x) for x in result.decode('utf-8').strip().split('\n')]
    return memory_per_gpu[0]

# clean cache
torch.cuda.empty_cache()
before_memory = get_gpu_memory()
print(torch.cuda.memory_summary(abbreviated=True))
subprocess.run(["nvidia-smi", "-i=0"], check=True)

# do data transfer
with torch.no_grad():
    torch.randn(2, 2).cuda()

# measure gpu memory usage
after_memory = get_gpu_memory()
print(f"Memory increase after operations: {after_memory - before_memory} MB")

# clean cache and measure again
torch.cuda.empty_cache()
after_empty_cache = get_gpu_memory()
print(torch.cuda.memory_summary(abbreviated=True))
subprocess.run(["nvidia-smi", "-i=0"], check=True)
print(f"Memory after empty_cache: {after_empty_cache} MB")
print(f"Persistent overhead: {after_empty_cache - before_memory} MB")
And the PyTorch output is:
|===========================================================================|
| PyTorch CUDA memory summary, device ID 0 |
|---------------------------------------------------------------------------|
| CUDA OOMs: 0 | cudaMalloc retries: 0 |
|===========================================================================|
| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
|---------------------------------------------------------------------------|
| Allocated memory | 0 B | 0 B | 0 B | 0 B |
|---------------------------------------------------------------------------|
| Active memory | 0 B | 0 B | 0 B | 0 B |
|---------------------------------------------------------------------------|
| Requested memory | 0 B | 0 B | 0 B | 0 B |
|---------------------------------------------------------------------------|
| GPU reserved memory | 0 B | 0 B | 0 B | 0 B |
|---------------------------------------------------------------------------|
| Non-releasable memory | 0 B | 0 B | 0 B | 0 B |
|---------------------------------------------------------------------------|
| Allocations | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Active allocs | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| GPU reserved segments | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Non-releasable allocs | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Oversize allocations | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Oversize GPU segments | 0 | 0 | 0 | 0 |
|===========================================================================|
Mon Apr 14 10:59:55 2025
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05 Driver Version: 520.61.05 CUDA Version: 11.8 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:1A:00.0 Off | N/A |
| 30% 27C P8 28W / 350W | 2MiB / 24576MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
Memory increase after operations: 828 MB
|===========================================================================|
| PyTorch CUDA memory summary, device ID 0 |
|---------------------------------------------------------------------------|
| CUDA OOMs: 0 | cudaMalloc retries: 0 |
|===========================================================================|
| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
|---------------------------------------------------------------------------|
| Allocated memory | 0 B | 512 B | 512 B | 512 B |
|---------------------------------------------------------------------------|
| Active memory | 0 B | 512 B | 512 B | 512 B |
|---------------------------------------------------------------------------|
| Requested memory | 0 B | 16 B | 16 B | 16 B |
|---------------------------------------------------------------------------|
| GPU reserved memory | 0 B | 2048 KiB | 2048 KiB | 2048 KiB |
|---------------------------------------------------------------------------|
| Non-releasable memory | 0 B | 2047 KiB | 2047 KiB | 2047 KiB |
|---------------------------------------------------------------------------|
| Allocations | 0 | 1 | 1 | 1 |
|---------------------------------------------------------------------------|
| Active allocs | 0 | 1 | 1 | 1 |
|---------------------------------------------------------------------------|
| GPU reserved segments | 0 | 1 | 1 | 1 |
|---------------------------------------------------------------------------|
| Non-releasable allocs | 0 | 1 | 1 | 1 |
|---------------------------------------------------------------------------|
| Oversize allocations | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Oversize GPU segments | 0 | 0 | 0 | 0 |
|===========================================================================|
Mon Apr 14 10:59:57 2025
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05 Driver Version: 520.61.05 CUDA Version: 11.8 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:1A:00.0 Off | N/A |
| 30% 29C P2 101W / 350W | 826MiB / 24576MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 430253 C python 824MiB |
+-----------------------------------------------------------------------------+
Memory after empty_cache: 826 MB
Persistent overhead: 826 MB
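As a side note, the same device-level reading could also be taken through NVML directly instead of parsing nvidia-smi output. A minimal sketch (assuming the nvidia-ml-py package, imported as pynvml, is available) would be:

# Sketch: query device-level memory via NVML instead of shelling out to nvidia-smi.
# Assumes the nvidia-ml-py package is installed (import name: pynvml).
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)   # GPU 0, as in the script above
info = pynvml.nvmlDeviceGetMemoryInfo(handle)   # .used / .free / .total in bytes
print(f"Used: {info.used / 1024**2:.0f} MiB / {info.total / 1024**2:.0f} MiB")
pynvml.nvmlShutdown()

(nvidia-smi is itself built on NVML, so this is only a convenience, not a different measurement.)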
I suspected that part of the memory usage might come from CUDA context initialization. To test this, I wrote a minimal C++ program that uses the CUDA Runtime API to perform the same task. The C++ version consumed around 255 MB of GPU memory, which I assume corresponds to the CUDA context overhead.
Here is my CUDA C++ code:
#include <cuda_runtime.h>
#include <cuda.h>
#include <iostream>
#include <vector>
#include <string>
#include <iomanip>

// print GPU 0 memory usage
class GPUMemoryMonitor {
public:
    GPUMemoryMonitor() {
        cudaGetDeviceCount(&deviceCount);
    }

    void printAllGPUMemoryUsage(const std::string& label = "") {
        if (!label.empty()) {
            std::cout << "===== " << label << " =====" << std::endl;
        }
        for (int i = 0; i < 1; i++) {
            printGPUMemoryUsage(i);
        }
        if (!label.empty()) {
            std::cout << std::string(label.length() + 12, '=') << std::endl;
        }
    }

    void printGPUMemoryUsage(int deviceId) {
        size_t free, total;
        cudaSetDevice(deviceId);
        cudaMemGetInfo(&free, &total);
        double usedMB = (total - free) / 1024.0 / 1024.0;
        double totalMB = total / 1024.0 / 1024.0;
        double percentUsed = 100.0 * (total - free) / total;
        std::cout << "GPU " << deviceId << ": "
                  << std::fixed << std::setprecision(2) << usedMB << " MB / "
                  << totalMB << " MB ("
                  << percentUsed << "%) used" << std::endl;
    }

private:
    int deviceCount;
};

int main() {
    GPUMemoryMonitor memMonitor;
    memMonitor.printAllGPUMemoryUsage("Initial GPU Memory State");

    cudaError_t cudaStatus;
    int deviceId = 0;
    int deviceCount;
    cudaStatus = cudaGetDeviceCount(&deviceCount);
    if (cudaStatus != cudaSuccess || deviceCount == 0) {
        std::cerr << "No CUDA devices found: " << cudaGetErrorString(cudaStatus) << std::endl;
        return 1;
    }
    cudaSetDevice(deviceId);

    cudaStream_t stream;
    cudaStatus = cudaStreamCreate(&stream);
    if (cudaStatus != cudaSuccess) {
        std::cerr << "cudaStreamCreate failed: " << cudaGetErrorString(cudaStatus) << std::endl;
        return 1;
    }
    memMonitor.printAllGPUMemoryUsage("After Stream Creation");

    const size_t numElements = 2 * 2;
    const size_t sizeInBytes = numElements * sizeof(float);
    std::vector<float> h_data(numElements);
    float* d_data = nullptr;
    cudaStatus = cudaMalloc((void**)&d_data, sizeInBytes);
    if (cudaStatus != cudaSuccess) {
        std::cerr << "cudaMalloc failed: " << cudaGetErrorString(cudaStatus) << std::endl;
        cudaStreamDestroy(stream);
        return 1;
    }
    memMonitor.printAllGPUMemoryUsage("After Device Memory Allocation");

    cudaStatus = cudaMemcpyAsync(d_data, h_data.data(), sizeInBytes, cudaMemcpyHostToDevice, stream);
    if (cudaStatus != cudaSuccess) {
        std::cerr << "cudaMemcpyAsync failed: " << cudaGetErrorString(cudaStatus) << std::endl;
        cudaFree(d_data);
        cudaStreamDestroy(stream);
        return 1;
    }
    memMonitor.printAllGPUMemoryUsage("After cudaMemcpyAsync");

    cudaStatus = cudaStreamSynchronize(stream);
    if (cudaStatus != cudaSuccess) {
        std::cerr << "cudaStreamSynchronize failed: " << cudaGetErrorString(cudaStatus) << std::endl;
        cudaFree(d_data);
        cudaStreamDestroy(stream);
        return 1;
    }
    memMonitor.printAllGPUMemoryUsage("After Stream Synchronization");

    cudaFree(d_data);
    memMonitor.printAllGPUMemoryUsage("After cudaFree");

    cudaStreamDestroy(stream);
    memMonitor.printAllGPUMemoryUsage("After Stream Destroy");

    return 0;
}
And the CUDA C++ output is:
===== Initial GPU Memory State =====
GPU 0: 254.75 MB / 24268.31 MB (1.05%) used
====================================
===== After Stream Creation =====
GPU 0: 254.75 MB / 24268.31 MB (1.05%) used
=================================
===== After Device Memory Allocation =====
GPU 0: 258.75 MB / 24268.31 MB (1.07%) used
==========================================
===== After cudaMemcpyAsync =====
GPU 0: 258.75 MB / 24268.31 MB (1.07%) used
=================================
===== After Stream Synchronization =====
GPU 0: 258.75 MB / 24268.31 MB (1.07%) used
========================================
===== After cudaFree =====
GPU 0: 254.75 MB / 24268.31 MB (1.05%) used
==========================
===== After Stream Destroy =====
GPU 0: 254.75 MB / 24268.31 MB (1.05%) used
================================
So, after subtracting the CUDA context overhead (~255 MB, according to the C++ program), PyTorch still seems to be using about 600 MB of additional GPU memory that I cannot account for.
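To separate the context and module-loading cost from anything the caching allocator does, one further check I have in mind is a minimal isolation test: initialize CUDA from PyTorch without creating any tensor, read the usage, and only then do the tiny transfer. A sketch, reusing the get_gpu_memory() helper from the script above:

# Sketch: measure device memory after CUDA context creation alone, before any tensor
# exists, then again after the first tiny transfer. Reuses get_gpu_memory() from above.
import torch

base = get_gpu_memory()
torch.cuda.init()             # initialize PyTorch's CUDA state on the default GPU
torch.cuda.synchronize()      # make sure initialization has actually completed
print(f"Context/init-only overhead: {get_gpu_memory() - base} MiB")

t = torch.randn(2, 2).cuda()  # the same tiny transfer as before
torch.cuda.synchronize()
print(f"Total after first tensor: {get_gpu_memory() - base} MiB")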
Here are my questions:
- Is this additional GPU memory usage by PyTorch expected or normal?
- Is there any way to trace or profile the source of GPU memory allocations? I tried using the NVIDIA Nsight suite but couldn’t get detailed memory allocation information.
- Is there any documentation or explanation for this extra ~600 MB of memory usage? I haven’t been able to find anything.
Any help or insights would be greatly appreciated. Thank you in advance!