GPU RAM fragmentation diagnostics

Following up on “Unable to allocate cuda memory, when there is enough of cached memory”: while there is no way to defragment NVIDIA GPU RAM, is there a way to get a memory allocation map? I’m asking in the simple context of a single process using the GPU exclusively.

The free memory reported by nvml can be very misleading due to fragmentation, so it would be useful to have some sort of fragmentation estimate, e.g. 1GB free, 0.3GB of which is in chunks smaller than 50MB.
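
To make the idea concrete, here is a minimal sketch of the kind of summary I have in mind, assuming one could somehow obtain a list of free-segment sizes (there is no such API that I know of, which is the whole point of this question; the numbers below are made up):

def fragmentation_summary(free_segments_mb, small_mb=50):
    " summarize a (hypothetical) list of free segment sizes, in MBs "
    total   = sum(free_segments_mb)
    small   = sum(s for s in free_segments_mb if s < small_mb)
    largest = max(free_segments_mb, default=0)
    return (f"{total:.0f}MB free, {small:.0f}MB of it in chunks < {small_mb}MB, "
            f"largest contiguous chunk {largest:.0f}MB")

# e.g. a badly fragmented 1GB of free memory
print(fragmentation_summary([300, 200] + [25]*20))
# 1000MB free, 500MB of it in chunks < 50MB, largest contiguous chunk 300MB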

It would also help with optimizing code, by comparing fragmentation maps before and after a piece of code runs.

I spent quite some time reading the NVIDIA forums but couldn’t find anything useful. Perhaps someone here knows of resources that could help with diagnosing GPU RAM fragmentation.

And a question about PyTorch’s GPU RAM allocation process: how does PyTorch choose which free segment to use? E.g. if the free list contains (a) 200MB and (b) 50MB, and PyTorch needs to allocate 20MB, will it search for the smallest free chunk that can fit 20MB and pick (b), or will it pick the first chunk that satisfies the request (a)?

Thank you.

I started working on a tool that will provide this solution.

I think I have the prototype working, but I’m stuck on not being able to emulate memory fragmentation, so I can’t test its correctness.

Is the following guaranteed to allocate a contiguous block of GPU RAM?

torch.ones((d, d)).cuda().contiguous()

I added contiguous(), but it doesn’t seem to make any difference in this case (and I checked with is_contiguous()).

So here is my attempt to create a hole in memory, such that the next allocation request is bigger than the hole, no single remaining free block is large enough for that request, yet the total free memory is sufficient. The allocation should then fail due to fragmentation, but it succeeds.

# this ensures we always test the same thing
buf = leave_free_mbs(1600)
    
                   # legend: [free block]  {used block}
                   # [1601]
x1 = mem_get(512)  # {512}[1089]
x2 = mem_get(512)  # {512}{512}[577]
print(f"have {mem_free():4d}, reclaiming first 512")
del x1             # [512]{512}[577]
x3 = mem_get(1024) # shouldn't be able to allocate 1024 contiguous mem
print(f"have {mem_free():4d}")

which outputs:

consuming 4054MB to bring free mem to 1600MBs
have 1601, allocating 512
have 1089, allocating 512
have  577, reclaiming first 512
have 1089, allocating 1024
have   65

So the last call to allocate 1024MB succeeds, despite there supposedly being only two free chunks, one of 512MB and another of ~577MB, separated by a 512MB used chunk. If my allocation function allocates contiguous memory, how can it succeed?

Am I doing something wrong?

Thank you.

Here is the whole program should you want to run it yourself.
Make sure to pip install nvidia-ml-py3 before you run it.

import pynvml, torch, gc

pynvml.nvmlInit()
id = torch.cuda.current_device()
def mem_free():
    gc.collect()
    torch.cuda.empty_cache()
    handle = pynvml.nvmlDeviceGetHandleByIndex(id)
    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return int( info.free / 2**20 )

def mem_report(): print(f"free mem={mem_free()}")

def mem_allocate_mbs(n, fatal=False): 
    " allocate n MBs, return the var holding it on success, None on failure "
    if n < 6: return None # don't try to allocate less than 6MB
    try:
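        # n MBs of float32: d*d*4 bytes = n*2**20 bytes  =>  d = 2**9 * sqrt(n)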
        d = int(2**9*n**0.5)
        return torch.ones((d, d)).cuda().contiguous()
    except Exception as e:
        if not fatal: return None
        raise e
        
def leave_free_mbs(n):
    " consume whatever memory is needed so that n MBs are left free "
    avail = mem_free()
    assert avail > n, f"already have less available mem than desired {n}MBs"
    consume = avail - n
    print(f"consuming {consume}MB to bring free mem to {n}MBs")
    return mem_allocate_mbs(consume, fatal=True)

def globals_unset(var_names):
    " this is useful for re-running the cell, so that it resets the initial state or cleanup at the end of the cell"
    for x in var_names: 
        if x in globals(): 
            del globals()[x]
            
def mem_get(n):
    print(f"have {mem_free():4d}, allocating {n}")
    return mem_allocate_mbs(n, fatal=True)

globals_unset(['x1', 'x2', 'x3', 'buf'])
_ = torch.ones(1).cuda()  # preload: initialize the CUDA context

# this ensures we always test the same thing
buf = leave_free_mbs(1600)
    
                   # legend: [free block]  {used block}
                   # [1600]
x1 = mem_get(512)  # {512}[1092]
x2 = mem_get(512)  # {512}{512}[576]
print(f"have {mem_free():4d}, reclaiming first 512")
del x1             # [512]{512}[576]
x3 = mem_get(1024) # shouldn't be able to allocate 1024 contiguous mem
print(f"have {mem_free():4d}")

# cleanup
globals_unset(['x1', 'x2', 'x3', 'buf'])

Well, failing to sort it out at the PyTorch level, I went down to the CUDA C level, and even there I find that cudaMalloc successfully allocates non-contiguous memory. No wonder I couldn’t get it to work at the PyTorch level.

So if I have two 0.5GB free fragments that are not adjacent, a cudaMalloc() of 0.9GB still succeeds just fine!

Can someone please do a sanity check? It’s a bit of a crude test, but it should hopefully be easy to follow. I have an 8GB card, so the sizes are tuned for it.

/* fragment gpu RAM by allocating a bunch of blocks and then releasing some in between, creating holes
   then try to allocate more than the size of the largest hole, but less than total free memory
   it appears that CUDA succeeds
   conclusion: cudaMalloc is not allocating contiguous memory
*/
#include <stdio.h>
#include <unistd.h>
#include <cuda.h>

const size_t Mb = 1<<20; // Assuming a 1Mb page size here

#define DSIZE0  410000000ULL //  ~400MB
#define DSIZE1 3144000000ULL // ~3000MB
#define DSIZE2  524000000ULL //  ~500MB
#define DSIZE3  630000000ULL //  ~600MB

void can_allocate() {
  size_t total;
  size_t avail;
  cudaError_t cuda_status = cudaMemGetInfo(&avail, &total);
  if ( cudaSuccess != cuda_status ) {
    printf("Error: cudaMemGetInfo fails, %s \n", cudaGetErrorString(cuda_status) );
    exit(EXIT_FAILURE);
  }

  printf("free: %.f, total %.f\n", (double)avail/Mb, (double)total/Mb);

  int *buf_d = 0;
  size_t nwords = total / sizeof(int);
  size_t words_per_Mb = Mb / sizeof(int);

  /* the only way to measure how much memory is allocatable is by trial and
     error, cudaMemGetInfo's available memory information is not reliable */
  while (cudaMalloc((void**)&buf_d,  nwords * sizeof(int)) == cudaErrorMemoryAllocation) {
    cudaFree(buf_d);
    nwords -= words_per_Mb;
    if (nwords < words_per_Mb) {
      // signal no free memory
      break;
    }
  }
  cudaFree(buf_d);
  /* clear last error */
  printf("err2:  %d\n", (int)cudaGetLastError());

  printf("can allocate:  %.fMB\n", (double)nwords/words_per_Mb);
}

int main() {
    int *d0, *d1, *d2, *d3, *d4;

    //cudaSetDevice(0);

    /* starting with 8GB free */
    /* legend: [allocated]{free} */

    // init - prealloc 500MB, including ~100MB CUDA ctx
    // [0.5]{7.5} - free total=7.5
    cudaMalloc(&d0, DSIZE0);
    printf("err1:  %d\n", (int)cudaGetLastError());

    // [0.5][0.5]{7.0} - free total=7.0
    cudaMalloc(&d1, DSIZE2);
    printf("err1:  %d\n", (int)cudaGetLastError());

    // [0.5][0.5][3]{4.0} - free total=4.0
    cudaMalloc(&d2, DSIZE1);
    printf("err2:  %d\n", (int)cudaGetLastError());

    // [0.5][0.5][3][0.5]{3.5} - free total=3.5
    cudaMalloc(&d3, DSIZE2);
    printf("err3:  %d\n", (int)cudaGetLastError());

    // [0.5][0.5][3][0.5][3]{0.5} - free total=0.5
    cudaMalloc(&d4, DSIZE1);
    printf("err2:  %d\n", (int)cudaGetLastError());

    // [0.5]{0.5}[3][0.5][3]{0.5} - free total=1.0
    cudaFree(d1);
    printf("err4:  %d\n", (int)cudaGetLastError());

    // [0.5]{0.5}[3]{0.5}[3]{0.5} - free total=1.5
    cudaFree(d3);
    printf("err4:  %d\n", (int)cudaGetLastError());

    // here we should have 1.5GB free in total, with 3 fragments of 0.5GB
    // this should say 0.5GB, but it says 1.6GB - so it allocates across fragments
    can_allocate();

    // another check: we shouldn't be able to allocate, say, 1GB of contiguous memory
    cudaMalloc(&d1, 2*DSIZE2);
    printf("err2:  %d\n", (int)cudaGetLastError());

    // sanity check 2GB at 1.5G free should fail
    // this fails, good
    cudaMalloc(&d1, 4*DSIZE2);
    printf("err2:  %d\n", (int)cudaGetLastError());

    sleep(1000);  /* keep consuming RAM */

    return 0;
}

What type of GPU are you testing on? The CUDA behavior changes between GPU architectures.

What type of GPU are you testing on? The CUDA behavior changes between GPU architectures.

GeForce GTX 1070 Ti (8GB)
NVIDIA driver 410.79 + CUDA 10

I think Pascal GPUs operate with 2 MB pages. (At least this appears to be the case on a P100; none of this is well documented).

I find that cudaMalloc successfully allocates non-contiguous memory.

Yes – to a certain extent. It can re-map pages, but not sub-pages. So you can allocate all your memory in chunks of the page size, free every other allocation (so you have lots of “holes”), and then allocate all the remaining space in one contiguous chunk. However, if your holes are smaller than the page size it won’t help. (For example, try allocating all memory in chunks of 2097153 bytes using the CUDA API. How much memory can you allocate?).
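
For the curious, here is a rough sketch of that experiment, calling cudaMalloc directly via ctypes so that PyTorch’s caching allocator stays out of the way. It assumes libcudart.so is on the library path, and uses a chunk size of 2MB + 1 byte:

import ctypes

cudart = ctypes.CDLL("libcudart.so")  # CUDA runtime library

def cuda_malloc(nbytes):
    " return a device pointer on success, None on failure "
    ptr = ctypes.c_void_p()
    err = cudart.cudaMalloc(ctypes.byref(ptr), ctypes.c_size_t(nbytes))
    return ptr if err == 0 else None

chunk = 2 * 2**20 + 1  # one byte more than the (presumed) 2MB page size
ptrs = []
while True:
    p = cuda_malloc(chunk)
    if p is None:
        break
    ptrs.append(p)
# note: this deliberately exhausts the GPU; memory is returned when the process exits
print(f"allocated {len(ptrs)} chunks of {chunk} bytes, "
      f"~{len(ptrs) * chunk / 2**20:.0f}MB total")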

PyTorch takes advantage of this behavior by freeing all unused, cached allocations when an allocation fails. This allows the driver to remap pages to make larger contiguous chunks. (And then we retry the allocation)
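
In pseudocode, that fallback path looks roughly like this (only a sketch, not the actual C++ implementation; the helper names are made up):

def alloc_with_retry(nbytes, cuda_malloc, free_all_cached_blocks):
    " rough sketch of the caching allocator's fallback path "
    ptr = cuda_malloc(nbytes)
    if ptr is None:
        free_all_cached_blocks()   # cudaFree every cached-but-unused block
        ptr = cuda_malloc(nbytes)  # retry: the driver may now remap pages
    if ptr is None:
        raise RuntimeError("CUDA out of memory")
    return ptr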

However, PyTorch isn’t able to release portions of blocks that are partially used (because there’s no CUDA API to do so). The CUDA driver is more powerful than PyTorch (since it can remap pages). If our goal were only to reduce fragmentation we would just use cudaMalloc and cudaFree directly. But cudaFree synchronizes the host with the GPU, which can hurt performance.

NVIDIA is much better positioned than PyTorch to write a good caching allocator as part of the CUDA API, but they don’t seem inclined to do so. (I’ve asked).

Note that there are two levels of potential fragmentation: fragmentation due to PyTorch and fragmentation due to the CUDA driver.

To answer some of your previous questions:

  • torch.ones((d, d)).cuda() will always allocate a contiguous block of GPU RAM (in the virtual address space)
  • Your allocation x3 = mem_get(1024) likely succeeds because PyTorch cudaFree’s x1 on failure and retries the allocation. (And as you saw, the CUDA driver can re-map pages).
  • PyTorch uses “best-fit” among cached blocks (i.e. smallest block). If there’s no block, it will try to cudaMalloc a new block.

EDIT: See THCCachingAllocator.cpp for source code
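
To make the “best-fit” point concrete, here is a toy model of the selection policy (only an illustration, not PyTorch’s actual code), using the 200MB/50MB example from the original question:

def pick_block(free_sizes_mb, request_mb):
    " toy best-fit: pick the smallest cached block that still fits the request "
    candidates = [s for s in free_sizes_mb if s >= request_mb]
    return min(candidates) if candidates else None  # None -> cudaMalloc a new block

print(pick_block([200, 50], 20))   # 50   - best fit, not the first fit (200)
print(pick_block([200, 50], 300))  # None - nothing fits, fall back to cudaMalloc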

@colesbury, your detailed reply is extremely helpful and insightful, thank you!

So as you explained I need to track fragmentation at a much lower level. But then it won’t let me allocate memory smaller than a page if that page is already partially used, am I correct?

Why does it even bother including the small fragments in the free memory counter? In addition to the current free memory counter, there really should be a “truly free” memory counter, which would exclude any partially filled pages.

How would you recommend I then try to diagnose the situation that led me to start on this path?

RuntimeError: CUDA out of memory. 
Tried to allocate 350.00 MiB 
(GPU 0; 7.93 GiB total capacity; 5.73 GiB already allocated; 
324.56 MiB free; 1.34 GiB cached)

My idea was to write a sort of memory mapper that I could splice into my code and then, by bisecting through its printouts, find the guilty parts that cause that much fragmentation.

Does the content of this error come from CUDA or PyTorch? It doesn’t look like PyTorch, since its memory counters always satisfy allocated <= cached. So it’s unclear what that ‘cached’ figure means.

Based on your explanation, that means there are a lot of small fragments sprinkled through the RAM, such that it can’t allocate that little memory (0.35GB) out of so much available (cached + free) memory (~1.7GB).

But then it won’t let me allocate memory smaller than a page if that page is already partially used, am I correct?

No – memory smaller than a page may use a partially filled page. It just won’t combine multiple partially filled pages together.

Why does it even bother including the small fragments in the free memory counter.

I think it mostly does not count them. For example, in my experience, if you make lots of small allocations, the free memory counter will decrease after the first small allocation and then remain unchanged for many subsequent allocations, as they use the remaining portion of the page.
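
A quick way to observe this effect from Python: allocate many tiny tensors and watch the NVML free-memory counter, which should drop once and then stay flat for a while. (Here the PyTorch caching allocator also packs the small tensors into one cudaMalloc’d block, so the counter only moves occasionally. A sketch, requires nvidia-ml-py3.)

import pynvml, torch

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(torch.cuda.current_device())
def nvml_free_mb():
    return pynvml.nvmlDeviceGetMemoryInfo(handle).free / 2**20

_ = torch.ones(1).cuda()  # initialize the CUDA context first
keep, prev = [], nvml_free_mb()
for i in range(50):
    keep.append(torch.ones(1024).cuda())  # ~4KB each
    cur = nvml_free_mb()
    if cur < prev:
        print(f"allocation #{i}: free memory dropped by {prev - cur:.0f}MB")
    prev = cur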

How would you recommend I then try to diagnose a situation that lead me to start on this path? My idea was to write a sort of memory mapper which I would bisect/splice my code with and through its printouts find the guilty parts that caused that much fragmentation.

I would just record the sizes of the malloc and free calls to a log. (Edit THCCachingAllocator.cpp).
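
As an aside, newer PyTorch releases also expose torch.cuda.memory_snapshot(), which reports the caching allocator’s segments and blocks and can give a similar view without patching the C++ code. A rough sketch, with the caveat that the exact field names may differ between versions:

import torch

def free_block_summary(small_mb=50):
    " summarize inactive (free) blocks inside PyTorch's cached segments "
    total = small = largest = 0
    for seg in torch.cuda.memory_snapshot():
        for blk in seg.get("blocks", []):
            if blk.get("state") != "inactive":
                continue
            mb = blk["size"] / 2**20
            total += mb
            largest = max(largest, mb)
            if mb < small_mb:
                small += mb
    print(f"{total:.0f}MB free in cached segments, {small:.0f}MB of it in "
          f"chunks < {small_mb}MB, largest chunk {largest:.0f}MB")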

Does the content of this error come from CUDA or pytorch?

It comes from PyTorch. “Cached” in this case means cached but not allocated.
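
The same numbers can be queried from Python, which makes it easier to see how the error message’s fields relate. A small sketch; torch.cuda.memory_reserved() is what older releases call memory_cached(), and the driver-level free value is approximated here via NVML:

import pynvml, torch

pynvml.nvmlInit()
dev = torch.cuda.current_device()
handle = pynvml.nvmlDeviceGetHandleByIndex(dev)

allocated = torch.cuda.memory_allocated(dev)              # in active use by tensors
reserved  = torch.cuda.memory_reserved(dev)               # held by the caching allocator
cached    = reserved - allocated                          # cached but not allocated
free      = pynvml.nvmlDeviceGetMemoryInfo(handle).free   # free at the driver level
total     = torch.cuda.get_device_properties(dev).total_memory

MB = 2**20
print(f"total {total/MB:.0f}MB | allocated {allocated/MB:.0f}MB | "
      f"cached-but-unallocated {cached/MB:.0f}MB | driver-free {free/MB:.0f}MB")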

OK, then I could write an app that fills out all the little nooks and crannies and then dumps a summary. It sounds like it would be easiest to do this directly in C, to avoid any overhead incurred by going through PyTorch.

I’m not sure how that could help with detecting fragmentation. Unless you mean looking at mallocs that are smaller than a page size and then accounting for the remaining fragments, but that sounds very unreliable/difficult to accomplish, since it would require an exact understanding of how CUDA does things internally. Basically this is what I did in my second post, except at the macro level.

I think this is diverging from the topic now, so I will create a new one for that and keep this one focused.

Sorry for asking about a non-PyTorch problem here, but I’ve hit a very weird issue. When using CUDA-OpenGL interop, I create and delete a window and other resources in a very long loop. After cudaGraphicsGLRegisterImage returns error code=2 (cudaErrorMemoryAllocation), I put the application to sleep and then test whether there is a GPU RAM fragmentation problem: not only does nvidia-smi report 1043MiB / 6144MiB used, but can_allocate() (with the other code above removed, leaving only can_allocate) is able to cudaMalloc 4772MB, which is absolutely enough for the app to run.

When just using OpenGL in this loop, with no CUDA, there is still an error after ~60000 iterations: ERROR OpenGL(1285) Out of memory, while nvidia-smi looks OK. I haven’t yet tested sleep and can_allocate() in that situation.

You might want to ask this question in an OpenGL forum.

Thanks!
Is there a maximum number of graphics driver allocations of GPU memory in CUDA (there is such a limit in Vulkan and OpenGL), either on Windows or on Linux (esp. Ubuntu)?

from: Bindless Textures - resident textures Limits - OpenGL / OpenGL: Advanced Coding - Khronos Forums
On Windows, the limit imposed by the Microsoft WDDM on your graphics driver is 4096 maximum allocations. On Linux, this limit is 4 billion. You can query this limit under Vulkan as VkPhysicalDeviceLimits::maxMemoryAllocationCount. See: