How to run the CUDA trace callback only for uncaught OOM errors?

There isn’t an easy way to do this, because the callback has to run when the OOM happens but before the stack unwinds; otherwise a lot of memory will already have been released by the time the snapshot is taken. One workaround is to record a snapshot on every OOM and save only the last one. Each snapshot takes some time to record, but the last one recorded will be the one from the actual uncaught OOM:

import pickle
import torch

snapshot = None  # overwritten on every OOM; only the last one is kept

def oom_observer(device, alloc, device_alloc, device_free):
    # Take a snapshot right after an OOM happens, before the stack unwinds.
    global snapshot
    snapshot = torch.cuda.memory._snapshot()

torch._C._cuda_attach_out_of_memory_observer(oom_observer)
try:
    ...  # application code goes here
except torch.cuda.OutOfMemoryError:
    print('saving allocated state during OOM')
    with open('oom_snapshot.pickle', 'wb') as f:
        pickle.dump(snapshot, f)
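
If you would rather avoid the global, the same "keep only the last snapshot" idea can be wrapped in a small helper. This is only a sketch: `LastSnapshotKeeper`, `dump_on` and the `take_snapshot` parameter are hypothetical names I made up, not part of PyTorch. The snapshot function is passed in, so you would hand it `torch.cuda.memory._snapshot` in real use, and the exception is re-raised after saving so the OOM still propagates.

```python
import pickle
from contextlib import contextmanager


class LastSnapshotKeeper:
    """Hypothetical helper: remembers only the most recent snapshot."""

    def __init__(self, take_snapshot):
        # take_snapshot: zero-argument callable returning a picklable object
        # (assumed to be torch.cuda.memory._snapshot in real use).
        self._take_snapshot = take_snapshot
        self.snapshot = None
        self.count = 0

    def observer(self, device, alloc, device_alloc, device_free):
        # Same four-argument signature as oom_observer above.
        # Each call overwrites the previous snapshot.
        self.snapshot = self._take_snapshot()
        self.count += 1

    @contextmanager
    def dump_on(self, exc_type, path):
        # Write the last snapshot to `path` only if exc_type escapes the block,
        # then re-raise so the error is not swallowed.
        try:
            yield self
        except exc_type:
            if self.snapshot is not None:
                with open(path, "wb") as f:
                    pickle.dump(self.snapshot, f)
            raise


# Intended use with PyTorch (needs a CUDA build, so left as comments):
# keeper = LastSnapshotKeeper(torch.cuda.memory._snapshot)
# torch._C._cuda_attach_out_of_memory_observer(keeper.observer)
# with keeper.dump_on(torch.cuda.OutOfMemoryError, "oom_snapshot.pickle"):
#     ...  # application code goes here
```

OOMs that the application catches and recovers from inside the `with` block still trigger the observer, but nothing is written to disk for them; only an error that escapes the block causes the last snapshot to be saved.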