How to trigger the CUDA memory trace callback only for uncaught OOM errors?

From this blog post by @zdevito, we can see how to attach a callback that runs every time an OOM happens, but the callback also fires for OOMs that are caught and retried. Is there a way to run the callback only for an uncaught OOM?

```python
import pickle

import torch

def oom_observer(device, alloc, device_alloc, device_free):
    # snapshot right after an OOM happened
    print('saving allocated state during OOM')
    snapshot = torch.cuda.memory._snapshot()
    with open('oom_snapshot.pickle', 'wb') as f:
        pickle.dump(snapshot, f)

torch._C._cuda_attach_out_of_memory_observer(oom_observer)
```

Looking at the C++ code (cuda/module.cpp and CUDACachingAllocator.cpp), I don’t see an obvious way to add a flag that restricts the callback to uncaught OOMs.

There isn’t an easy way to do this, because the callback has to run at the moment the OOM happens, before the stack unwinds; otherwise a lot of memory will already have been released by the time the snapshot is taken. What you can do is take a snapshot on every OOM, overwrite the previous one each time, and save only the last one. Capturing each snapshot takes some time, but the snapshot still held when the exception escapes your code is the one from the actual, uncaught OOM:

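```python
import pickle

import torch

snapshot = None  # most recent OOM snapshot; overwritten on every OOM

def oom_observer(device, alloc, device_alloc, device_free):
    # snapshot right after an OOM happened (caught or not)
    global snapshot
    snapshot = torch.cuda.memory._snapshot()

torch._C._cuda_attach_out_of_memory_observer(oom_observer)
try:
    ...  # <application code>
except torch.cuda.OutOfMemoryError:
    # only an OOM that escaped the application code reaches this point
    print('saving allocated state during OOM')
    with open('oom_snapshot.pickle', 'wb') as f:
        pickle.dump(snapshot, f)
```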
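The same "overwrite on every OOM, persist only at the top level" idea can be packaged as a small reusable helper. This is a minimal sketch under stated assumptions: the class `LastOOMSnapshot`, its `save_if_uncaught` method, and the injectable `snapshot_fn` / `attach_fn` / `oom_error` parameters are hypothetical names, not part of PyTorch. By default they fall back to the same private calls used above (`torch.cuda.memory._snapshot`, `torch._C._cuda_attach_out_of_memory_observer`), which may change between releases. The observer is attached once in the constructor, since each attach call appears to register an additional observer. Unlike the snippet above, it re-raises the OOM after saving so the failure is still reported.

```python
import pickle
from contextlib import contextmanager


class LastOOMSnapshot:
    """Hypothetical helper: keep only the latest OOM snapshot, save it only if the OOM escapes.

    snapshot_fn / attach_fn / oom_error default to the private PyTorch hooks,
    but can be injected (e.g. to exercise the logic without a GPU).
    """

    def __init__(self, snapshot_fn=None, attach_fn=None, oom_error=None):
        if snapshot_fn is None or attach_fn is None or oom_error is None:
            import torch  # only needed when the real PyTorch hooks are used
            snapshot_fn = snapshot_fn or torch.cuda.memory._snapshot
            attach_fn = attach_fn or torch._C._cuda_attach_out_of_memory_observer
            oom_error = oom_error or torch.cuda.OutOfMemoryError
        self._snapshot_fn = snapshot_fn
        self.oom_error = oom_error
        self.latest = None   # most recent snapshot, overwritten on every OOM
        self.oom_count = 0   # how many OOMs (caught or not) the observer has seen
        # Register once: the observer runs for every OOM before the stack unwinds.
        attach_fn(self._observer)

    def _observer(self, device, alloc, device_alloc, device_free):
        self.oom_count += 1
        self.latest = self._snapshot_fn()

    @contextmanager
    def save_if_uncaught(self, path):
        try:
            yield self
        except self.oom_error:
            # Only an OOM that nothing inside the block handled gets here,
            # so self.latest is the snapshot taken for that OOM.
            with open(path, "wb") as f:
                pickle.dump(self.latest, f)
            raise  # keep the OOM uncaught for the caller
```

Usage would look like `recorder = LastOOMSnapshot()` once at startup, then wrapping the entry point in `with recorder.save_if_uncaught('oom_snapshot.pickle'):`, where the wrapped code is your application. OOMs caught and retried inside the block still cost a snapshot each, but nothing is written to disk for them.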