Measuring peak memory usage: tracemalloc for pytorch?

I’ve been working on tools for memory usage diagnostics and management (ipyexperiments) to help get more out of the limited GPU RAM. The features include tracking real used and peak memory (GPU and general RAM). Knowing the peak memory usage is crucial for being able to fit into the available RAM.

Initially, I was spinning off a thread that recorded peak memory usage while the normal process runs.

Then I discovered that I can use python’s tracemalloc to measure allocated general RAM and, more importantly, peak memory usage:

import tracemalloc
tracemalloc.start()
do_something_that_consumes_ram_and_releases_some()
# show how much RAM the above code allocated and the peak usage
current, peak =  tracemalloc.get_traced_memory()
print(f"{current:0.2f}, {peak:0.2f}")
tracemalloc.stop()

So I no longer need the monitor thread for tracking general RAM, but I still need it for GPU RAM. Perhaps there is a way to have functionality similar to python’s tracemalloc in pytorch? If not, and it’s feasible, perhaps this could be a feature request? I’m not asking for all of tracemalloc’s features, just the limited capacity to measure allocated and peak GPU RAM, as demonstrated in the example above.

Thank you.

1 Like

See https://pytorch.org/docs/master/notes/cuda.html#memory-management

I fail to see how your link answers my question, @SimonW. I wasn’t asking about pytorch memory management; I was asking about tracemalloc-like functionality.

Currently to get the peak GPU RAM used by pytorch, I need to:

  1. Start a thread that monitors gpu used memory every few msecs
  2. Run the real code in the main process
  3. Stop the thread and collect the results.

It works, but it’s awkward and imprecise. And, of course, I have to clear the cache to get the actual memory usage.

1 Like

From the link:

You can use memory_allocated() and max_memory_allocated() to monitor memory occupied by tensors, and use memory_cached() and max_memory_cached() to monitor memory managed by the caching allocator.
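
E.g., roughly (a minimal sketch printing the four readings in MB):

import torch

def mb(x): return int(x / 2**20)

print(f"allocated: {mb(torch.cuda.memory_allocated())} (max: {mb(torch.cuda.max_memory_allocated())}), "
      f"cached: {mb(torch.cuda.memory_cached())} (max: {mb(torch.cuda.max_memory_cached())})")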

I wish it were so.

I’m using nvml to get the memory readings, and they don’t match what the torch memory functions report.


First, the cache: if I run torch.cuda.empty_cache(), shouldn’t these two return 0?
torch.cuda.memory_cached(), torch.cuda.max_memory_cached()
They don’t. Why?

Calling empty_cache() can release all unused cached memory from PyTorch so that those can be used by other GPU applications.

Does “can” mean “will”? That’s very ambiguous if the docs can’t state exactly what it does.

I understand that torch.cuda.max_memory_cached() probably measures the highest amount of cache ever allocated during a given process and never goes down. That’s of low utility, but fair enough.

Why, then, doesn’t torch.cuda.memory_cached() report 0 after the cache was cleared?


Then comes torch.cuda.max_memory_allocated(), which again is of very little utility, since it appears to show the maximum memory ever allocated, so once you hit some peak, you can no longer measure any peaks below that highest one.

Let’s compare the torch memory reports with nvml reports in MBs (I call torch.cuda.empty_cache() before each measurement to kill the caching effects - and hoping that it actually does what it advertises to do).

Report columns:
A: torch.cuda.memory_allocated() (delta)
B: torch.cuda.max_memory_allocated() (delta)
C: nvml’s used memory (delta)
D: nvml’s peak memory, measured via a peak monitor thread (delta)

            A     B    C     D
epoch 1    25  6050   80  6220
epoch 2     0     0    0   924
epoch 3     0     0    0   924
Note how column B is stuck at the same number, whereas column D shows that the peak memory usage differs between the first epoch and the rest. I double-checked that by watching live nvidia-smi output. The GPU RAM consumption peaks at around 7GB during the first run and only at 1.6GB during subsequent runs.

What pytorch needs is a stop/start (reset) method for the max measurements, like tracemalloc has. Otherwise it’s like a depth gauge that only indicates the deepest depth one has ever gone to, but gives zero indication of the actual depth profile.

You can also see from the little table that the actual GPU memory usage is larger than what pytorch reports. That means pytorch isn’t reporting the extra memory the GPU card allocates beyond what pytorch requested, though the difference is not huge.

Also, to clarify: I could easily measure peak memory usage if I were dealing with low-level code. Just take a snapshot of used memory before and after. What I’m trying to profile is top-level code that internally does a whole bunch of allocations and deallocations.
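
E.g., for a single low-level allocation the before/after snapshot is trivial (a minimal sketch using pynvml; the 256MB tensor just stands in for arbitrary low-level code):

import torch, pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(torch.cuda.current_device())

def used_mb():
    # total memory currently in use on the card, in MB, as nvml sees it
    return int(pynvml.nvmlDeviceGetMemoryInfo(handle).used / 2**20)

before = used_mb()
x = torch.ones((2**13, 2**13)).cuda()  # a single ~256MB allocation
after = used_mb()
print(f"allocated: {after - before}MB")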

So at the moment a peak memory monitoring thread seems to be the only viable solution for measuring intermediate steps. Staying away from pytorch’s memory reporting functions and going directly to nvml seems to help as well, getting the numbers straight from the horse’s mouth, because that is really the only thing that counts when it comes to OOM problems.

Please don’t get me wrong, I’m not saying that pytorch’s memory reporting functions are incorrect. I’m just saying that they show only part of the picture and more detailed views are needed to be able to tune up one’s code more efficiently.

We are talking about a single process using single GPU setup here, so there is no interference from other processes.

I just hope that torch.cuda.empty_cache() really works by releasing all of its unused cache, since if it doesn’t, the measurements are incorrect.

edit: I changed the numbers in the table to deltas (before and after), so it’s easier to make sense out of it.

1 Like

And for some of you, code and hard numbers might speak better than words, so here is a simplified typical situation where some RAM gets allocated and some gets released:

import torch

def consume_gpu_ram(n): return torch.ones((n, n)).cuda()
def consume_gpu_ram_256mb(): return consume_gpu_ram(2**13)

# should be: 256 used, 512 peaked
c1 = consume_gpu_ram_256mb()
c2 = consume_gpu_ram_256mb()
del c1

I want a tool that will tell me that when this program has run, memory consumption peaked at 512MB and then finally stabilized at 256MB (we are talking deltas here). I hope it’s easy to see:

start:                 -> using   0MB
line 1: allocate 256MB -> using 256MB
line 2: allocate 256MB -> using 512MB
line 3: free 256MB     -> using 256MB

There is no way the current pytorch tools can tell me that there was a peak of 512MB there, unless it’s the very first code that was run, or no previously run code ever consumed more RAM than this program requires. So here is a better example:

# part 1: render torch.cuda.max_memory_allocated() useless for future peak computations 
# in lower ranges by allocating a much bigger chunk of RAM and then freeing it.
z = [consume_gpu_ram_256mb() for i in range(4)] # 1GB
del z

# part 2: now measure:
# should be: 256 used, 512 peaked
c1 = consume_gpu_ram_256mb()
c2 = consume_gpu_ram_256mb()
del c1

torch.cuda.max_memory_allocated() will now report 1024MB, even though part 2 consumed only 512MB at its peak.

And here is a simple program that shows how the cuda tools fall short, and how the workaround with a peak-measuring thread does measure the right thing (with a possible small error due to the thread’s unpredictable timing).

You will need pynvml installed: pip/conda install nvidia-ml-py3

import threading, torch, time, pynvml

def preload_pytorch():
    torch.ones((1, 1)).cuda()

def gpu_mem_used(id):
    handle = pynvml.nvmlDeviceGetHandleByIndex(id)
    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return int(info.used/2**20)

def gpu_mem_used_no_cache(id):
    torch.cuda.empty_cache()
    return gpu_mem_used(id)

def peak_monitor_start():
    global peak_monitoring
    peak_monitoring = True

    # this thread samples GPU RAM usage until peak_monitor_stop() flips the flag
    peak_monitor_thread = threading.Thread(target=peak_monitor_func)
    peak_monitor_thread.daemon = True
    peak_monitor_thread.start()

def peak_monitor_stop():
    global peak_monitoring
    peak_monitoring = False

def peak_monitor_func():
    global nvml_peak, peak_monitoring
    nvml_peak = 0
    id = torch.cuda.current_device()

    while True:
        nvml_peak = max(gpu_mem_used(id), nvml_peak)
        if not peak_monitoring: break
        time.sleep(0.001) # 1msec

def consume_gpu_ram(n): return torch.ones((n, n)).cuda()
def consume_gpu_ram_256mb(): return consume_gpu_ram(2**13)

peak_monitoring = False
nvml_peak = 0
preload_pytorch()
pynvml.nvmlInit()
id = torch.cuda.current_device()

# push the pytorch's peak gauge high up and then release the memory
z = [consume_gpu_ram_256mb() for i in range(4)] # 1GB
del z

peak_monitor_start()
nvml_before = gpu_mem_used_no_cache(id)
cuda_before = int(torch.cuda.memory_allocated()/2**20)

# should be: 256 used, 512 peaked
c1 = consume_gpu_ram_256mb()
c2 = consume_gpu_ram_256mb()
del c1

# code finished
peak_monitor_stop()
nvml_after = gpu_mem_used_no_cache(id)
cuda_after = int(torch.cuda.memory_allocated()/2**20)
cuda_peak  = int(torch.cuda.max_memory_allocated()/2**20)
print("nvml:", nvml_after-nvml_before, nvml_peak-nvml_before)
print("cuda:", cuda_after-cuda_before, cuda_peak-cuda_before)

Output:

nvml: 256 512
cuda: 256 1024

cuda tools can’t give me the right answer of 512MB here.

Now look at the simplicity of tracemalloc doing the same thing:

import tracemalloc, numpy as np

def consume_cpu_ram(n): return np.ones((n, n))
def consume_cpu_ram_128mb(): return consume_cpu_ram(2**12)

# push the process' peak gauge high up and then release the memory
z = [consume_cpu_ram_128mb() for i in range(8)] # 1GB
del z

tracemalloc.start()

# expecting peak requirements of 256MB, and final 128MB
a1 = consume_cpu_ram_128mb()
a2 = consume_cpu_ram_128mb()
del a1

cpu_current, cpu_peak = [int(x / 2**20) for x in tracemalloc.get_traced_memory()]
tracemalloc.stop()

print(cpu_current, cpu_peak)

Output:

128 256

pytorch could do exactly the same; here is some pseudo-code:

# pseudo-code: max_memory_allocated_local()/..._reset() stand for the proposed resettable counter
class max_memory_allocated_local():
    def start(self):
        self.begin = memory_allocated()
        max_memory_allocated_local_reset()       # proposed: reset the local peak gauge to zero
    def stop(self):
        self.end  = memory_allocated()
        self.peak = max_memory_allocated_local() # proposed: the peak since the last reset
    def get_traced_memory(self):
        return self.end-self.begin, self.peak-self.begin

It introduces a local max_memory_allocated_local counter, which can be reset by the user, but otherwise works the same as max_memory_allocated().
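
For completeness, a hypothetical usage of that pseudo-API might look something like this (none of these names exist today; they only illustrate the shape of the proposal):

counter = max_memory_allocated_local()
counter.start()
run_the_code_to_profile()   # placeholder for the code being measured
counter.stop()
used, peak = counter.get_traced_memory()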

1 Like

The section is just a FAQ; if you click on the function names it leads you to the doc pages, which are very clear about what the functions do. There is no need to guess.

These numbers naturally don’t match nvidia-smi exactly because of the cuda context.

As said in the page, empty_cache only frees unused cached memory. If a memory block in the caching allocator backs any tensor, it won’t get freed. So memory_cached is not 0 after you call empty_cache.
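
For instance, roughly (a minimal sketch):

import torch

x = torch.ones((2**13, 2**13)).cuda()  # ~256MB tensor backed by a cached block
torch.cuda.empty_cache()
# the block backing x cannot be released, so the cache does not drop to zero
print(torch.cuda.memory_allocated(), torch.cuda.memory_cached())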

I agree that resetting methods will be useful. Please submit a feature request.

These numbers naturally don’t match nvidia-smi exactly because of the cuda context.

Which renders the pytorch method memory_allocated() not so useful, because knowing how much memory was really allocated is exactly what one needs when trying to calculate hyperparameters based on available free memory. But I suppose there is no way pytorch could even try to estimate the extras that it can’t account for. It sounds like something is off in that logic, but unfortunately I don’t know enough about CUDA internals to say why it can’t communicate to the client how much memory it actually allocated, so that memory_allocated() could match the real thing. So I guess direct access to the GPU via nvml is really the only method I know of for precise measurements, and pytorch’s memory_allocated() is more of an estimate.

I agree that resetting methods will be useful. Please submit a feature request.

Done: user-resettable torch.cuda.max_memory_allocated() · Issue #15968 · pytorch/pytorch · GitHub

Thank you for following up on this thread, @SimonW

You are right. Unfortunately we can’t know the exact size of the cuda context. The memory_allocated/cached methods report all memory allocated by pytorch, but the cuda context is out of our control. However, I think you can safely assume that the cuda context is 100~250MB.

I submitted a patch to add those methods. Could you check the PR (linked in the issue) and let me know if they satisfy your needs?

1 Like

Are you talking about the context created when cuda is loaded for the first time? I get a pretty consistent 0.5GB per process with cuda100.

If you’re talking about cuda context allocated during the actual use of cuda, how can it be a relatively fixed number, regardless of memory usage? Clearly, I know very little about the cuda context and can only talk about the actual usage numbers I see. Do you have a recommendation for a concise description/guide to read to understand the cuda context without needing to study something huge?

Thank you.

It solves the problem. Thank you very much.
Here is a use case that others can use; feel free to add it to the test suite:

import torch
def consume_gpu_ram(n): return torch.ones((n, n)).cuda()
def consume_gpu_ram_256mb(): return consume_gpu_ram(2**13)

def b2mb(x): return int(x/2**20)
class TorchTracemalloc():

    def __enter__(self):
        self.begin = torch.cuda.memory_allocated()
        torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
        return self

    def __exit__(self, *exc):
        self.end  = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used   = b2mb(self.end-self.begin)
        self.peaked = b2mb(self.peak-self.begin)
        print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")

# push the process' peak gauge high up and then release all the memory
# expecting 0 used / 1024 peaked
with TorchTracemalloc() as tt:
    z = [consume_gpu_ram_256mb() for i in range(4)] # 1GB
    del z
assert tt.used == 0 and tt.peaked == 1024

# allocate, allocate, release half
# expecting 256 used / 512 peaked
with TorchTracemalloc() as tt:
    # should be: 256 used, 512 peaked
    c1 = consume_gpu_ram_256mb()
    c2 = consume_gpu_ram_256mb()
    del c1
assert tt.used == 256 and tt.peaked == 512
del c2 # reset for next test

# allocate, allocate, release all
# expecting 0 used / 512 peaked
with TorchTracemalloc() as tt:
    # should be: 0 used, 512 peaked
    c1 = consume_gpu_ram_256mb()
    c2 = consume_gpu_ram_256mb()
    del c1, c2
assert tt.used == 0 and tt.peaked == 512

# allocate, don't release
# expecting 1536 used / 1536 peaked
with TorchTracemalloc() as tt:
    z = [consume_gpu_ram_256mb() for i in range(6)]
assert tt.used == 1536 and tt.peaked == 1536
del z # reset for next test

The asserts already do the checks, but there is a visual check as well; output:

delta used/peak    0/1024
delta used/peak  256/ 512
delta used/peak    0/ 512
delta used/peak 1536/1536

Yes. Each cuda context is a fixed size. For PyTorch, we use something called the primary ctx, which is unique per process.
The rest of the 0.5GB comes from the caching allocator. Basically, when you first allocate cuda memory, if it is smaller than what we call a “block” (which I think is 256MB), we allocate a whole block of memory and cache the rest. Also, using random sampling methods for the first time allocates additional cuda rng state, but that is also one-time and fixed size.

Edit: The above info is not accurate. See follow up below.

1 Like

Basically, when you first allocate cuda memory, if it is smaller than what we call a “block” (which I think is 256MB), we allocate a whole block of memory and cache the rest.

I’m trying to see what you’re saying by measuring it, and it doesn’t match up:

import torch 
import pynvml
pynvml.nvmlInit()
preload = 0
prev = 0
def nvml_used():
    handle = pynvml.nvmlDeviceGetHandleByIndex(torch.cuda.current_device())
    info   = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return b2mb(info.used)
def b2mb(x): return int(x/2**20)
def consume_gpu_ram(n): return torch.ones((n, n)).cuda()
def consume_gpu_ram_64mb():  return consume_gpu_ram(2**12)
def consume_gpu_ram_256mb(): return consume_gpu_ram(2**13)
def mem(): 
    global preload, prev
    this = nvml_used()
    delta_cached = this - preload
    delta_used = this - prev
    prev = this
    print(f"   nvml used: {delta_used:4d}, allocated: {delta_cached:4d}")
    print(f"pytorch used: {b2mb(torch.cuda.memory_allocated()):4d}, allocated: {b2mb(torch.cuda.memory_cached()):4d}\n")

print("preloading:")
mem()
_ = torch.ones((1, 1)).cuda()
mem()
preload = nvml_used()

print("\nrunning:")
x1 = consume_gpu_ram_64mb()
mem()
x2 = consume_gpu_ram_256mb()
mem()
del x2
mem()
x3 = consume_gpu_ram_64mb()
mem()

So we have two parts to the program: the first forces the loading of everything needed to run cuda, which allocates 0.5GB on the card; the second does a few allocations. We measure torch.cuda’s allocated and cached numbers and compare them to nvml’s. I rigged the nvml reporting in mem() to match torch.cuda.memory_cached()’s behavior by subtracting the memory that was in use right after preloading, hence the different behavior during and after preloading.

The nvml “used” numbers don’t report the right thing once memory gets released and caching kicks in. empty_cache() would need to be called to get them right, but that would mess with the memory_cached reports this code is exploring, so just ignore the nvml “used” numbers.

Running it:

preloading:
   nvml used:   10, allocated:   10
pytorch used:    0, allocated:    0

   nvml used:  495, allocated:  505
pytorch used:    0, allocated:    1


running:
   nvml used:   64, allocated:   64
pytorch used:   64, allocated:   65

   nvml used:  256, allocated:  320
pytorch used:  320, allocated:  321

   nvml used:    0, allocated:  320
pytorch used:   64, allocated:  321

   nvml used:    0, allocated:  320
pytorch used:  128, allocated:  321

First, where is that block of cache you were referring to? When 0.5GB is consumed, pytorch reports 1MB of cache, compared to your saying that it should have allocated a block of about 256MB in cache.

Second, any reason why memory_allocated() isn’t including those 0.5GB? Surely, for a user who isn’t going to dig into the details this would appear very confusing. If they start with a card not used by any process and allocate a 100MB tensor, they would see 600MB consumed on the card, but only 100MB reported by pytorch.

Finally, from this report I understand that I didn’t originally get what memory_cached() really is. memory_cached() isn’t showing how much memory pytorch has cached that’s not currently used. It just says how much memory pytorch has pre-allocated, only some of which may be free. So the memory actually available for use in the cache is memory_cached() - memory_allocated().
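
So a rough sketch of the picture as I now understand it:

import torch

cached    = torch.cuda.memory_cached()     # memory the caching allocator has grabbed from the card
allocated = torch.cuda.memory_allocated()  # the part of it currently backing live tensors
free_inside_cache = cached - allocated     # what's still usable for new tensors without asking the card
print(int(free_inside_cache / 2**20), "MB free inside the cache")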

So my only remaining puzzle is whether any of the 0.5GB allocated on the first cuda call is ever available to the user for allocating tensors, because currently it appears that when I use pytorch my card’s RAM is really TOTALRAM - 0.5GB*num_of_processes. So if 2 processes use the card, I lose 1GB (!) of RAM.

Thank you.

edit: in case you decide to study the code, do note that it has an issue of reporting delta used for nvml, but absolute used (allocated) for torch.cuda. So the first column (used) is somewhat bogus. While it’s easy to fix the pytorch “used” to report a delta, nvml can’t report the correct numbers here because of caching, so really just ignore the first column.

Apologies. I went ahead and checked the code. The smallest block size is 1MB. So the memory usage you see is not related to that.

However, it is still very possible that most of the 0.5GB is CUDA ctx. The actual size depends on the CUDA version & card type (reference). I checked with a 1080Ti & CUDA 8 I have access to; it takes 343MB. And from my memory, on a Titan Xp with CUDA 9 it took ~200MB.

Google search shows other people reporting a ctx as large as 500MB: https://stackoverflow.com/questions/12109985/cudasetdevice-allocates-more-than-580-mb-of-global-memory/14101969.

You can check by running another CUDA program and see if it also allocates ~500MB on the first CUDA call.

Btw, in your code, do you mean to use memory_cached rather than max_memory_cached?

OK, so 1MB it is. We are on the same page now.

However, it is still very possible that most of the 0.5GB is CUDA ctx. The actual size depends on the CUDA version & card type (reference). I checked with a 1080Ti & CUDA 8 I have access to; it takes 343MB. And from my memory, on a Titan Xp with CUDA 9 it took ~200MB.

Fascinating - the larger the card the less overhead! Mine is 8GB and is taking 0.5GB.

Now I know it’s card-dependent. Thank you!

Btw, in your code, do you mean to use memory_cached rather than max_memory_cached?

Yes, you’re correct. Fixed. Thank you.

Going back to the initial question: would it be difficult to add another function that measures the allocated memory but is resettable too?

Here is the tracemalloc.get_traced_memory() sample again:

import tracemalloc
tracemalloc.start()
do_something_that_consumes_ram_and_releases_some()
# show how much RAM the above code allocated and the peak usage
current, peak =  tracemalloc.get_traced_memory()
print(f"{current:0.2f}, {peak:0.2f}")
tracemalloc.stop()

So, with your patch, we now have the peak covered, by calling torch.cuda.reset_max_memory_allocated() before the code to be profiled is run. But the only way to measure the memory consumed by the same code currently is to force empty_cache() before and after the code is run. Note, consumed here means allocated minus freed by a given piece of code, disregarding the caching mechanism.

Except, of course, it must leave torch.cuda.memory_allocated() alone, since we don’t want to break that functionality. So it’ll require 2 functions: one that measures and one that resets it.
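
Something along these lines, with purely hypothetical names just to illustrate the proposed pair (a resetter plus a reader; neither exists in pytorch):

# hypothetical sketch - these two functions do not exist, they only illustrate the request
torch.cuda.reset_memory_allocated_counter()        # zero a separate, resettable 'allocated' counter
run_the_code_to_profile()                          # placeholder
consumed = torch.cuda.memory_allocated_counter()   # allocations minus frees since the reset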

1 Like

I don’t have access to the Xp now, so I could have remembered wrong :).

What do you mean by “measure the memory consumed by code”? If you are looking for how large the new tensors created by running that segment of code are, can’t you just look at the difference of memory_allocated?
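
I.e., something like this (run_the_segment is just a placeholder for the code being measured):

before = torch.cuda.memory_allocated()
run_the_segment()
after = torch.cuda.memory_allocated()
consumed = after - before  # bytes held by the new tensors that survived the segment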

1 Like

If you are looking for how large the new tensors created by running that segment of code are, can’t you just look at the difference of memory_allocated?

You’re correct, @SimonW. It works just fine. For some reason I thought caching would affect the numbers, but it doesn’t. I verified it with some more tests.

So I am all set, just need to wait for pytorch-1.0.1 to be out.

Thank you, again.

@SimonW, I have been thinking about the solution you implemented, and there is a need for scoped memory measurements, where scopes can overlap or be nested.

Scenario 1: an application relying on the normal (pre-reset implementation) behavior of max_memory_allocated or max_memory_cached could now malfunction if some other application resets either or both (action at a distance).

Scenario 2: two profilers measuring different scopes. Say one measures at the function level, another at a wider or narrower scope. Since there is only one counter, they will be resetting each other’s measurements. Python’s tracemalloc has the same issue, since it doesn’t create an object instance for the counter.

The 2nd scenario is not hypothetical; it’s an actual need I have right now, as I have different profilers measuring different scopes. They are in different applications, so they can’t really communicate with each other to keep their reset calls in sync. E.g., I have one profiler running on the train loop at epoch level, another at the jupyter cell level, and yet another on larger parts of the notebook. And unfortunately, my current measuring-thread approach is clearly failing to catch all peaks :frowning: so it’d be extremely helpful to be able to switch to max_memory_allocated, but with different instances of it in different scopes.

So I need to be able to do something like:

max_obj1 = MaxMemoryAllocated()
# run some code 1
for epoch in epochs:
    max_obj2 = MaxMemoryAllocated()
    # run epoch code
    peak_epoch = max_obj2.peak()
# run some code ...
peak = max_obj1.peak()
del max_obj1

Of course, those would be unrelated applications; this code sample just demonstrates how their execution would overlap and why the current implementation is insufficient.

Also, am I missing a torch.cuda.free_memory or total_memory? How can I find out from torch.cuda how much memory I have available (or the total, from which I could derive free)?
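
(In the meantime I get those numbers straight from nvml; a minimal sketch:)

import torch, pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(torch.cuda.current_device())
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
print(f"total: {int(info.total/2**20)}MB, "
      f"free: {int(info.free/2**20)}MB, "
      f"used: {int(info.used/2**20)}MB")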

Thank you.