From this blog post by @zdevito, we see that we can add a callback that runs every time an OOM occurs, but this callback also fires for OOMs that are caught and retried. Is there a way to run the callback only when it is an uncaught OOM?
import torch
from pickle import dump

def oom_observer(device, alloc, device_alloc, device_free):
    # snapshot right after an OOM happened
    print('saving allocated state during OOM')
    snapshot = torch.cuda.memory._snapshot()
    with open('oom_snapshot.pickle', 'wb') as f:
        dump(snapshot, f)

torch._C._cuda_attach_out_of_memory_observer(oom_observer)
There isn’t an easy way to do it, because the callback needs to run when the OOM happens but before the stack unwinds; otherwise a lot of memory will be released before we can take the snapshot. One workaround is to record a snapshot in memory on every OOM and write only the last one to disk, once the exception reaches your outermost handler. Each snapshot takes some time to record, but the one that ends up saved is the snapshot from the actual uncaught OOM:
import torch
from pickle import dump

snapshot = None

def oom_observer(device, alloc, device_alloc, device_free):
    # snapshot right after an OOM happened; a later OOM overwrites it
    global snapshot
    snapshot = torch.cuda.memory._snapshot()

torch._C._cuda_attach_out_of_memory_observer(oom_observer)

try:
    ...  # <application code>
except torch.cuda.OutOfMemoryError:
    # only reached when the OOM was not caught and retried further down
    print('saving allocated state during OOM')
    with open('oom_snapshot.pickle', 'wb') as f:
        dump(snapshot, f)
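
If you want to reuse this pattern in several places, here is a minimal sketch of the same "record every time, save only the last one" idea wrapped in a small helper. The class name `LastOOMSnapshot`, the method `save_if_uncaught`, and its parameters are hypothetical names made up for this illustration, not part of PyTorch. The snapshot function and the exception type are passed in as arguments so the logic can be tried without a GPU. Unlike the snippet above, this sketch re-raises the exception after saving; that is a design choice, not a requirement.

```python
import pickle
from contextlib import contextmanager


class LastOOMSnapshot:
    """Hypothetical helper: keeps only the most recent OOM snapshot in memory."""

    def __init__(self, take_snapshot):
        # take_snapshot: zero-argument callable returning a picklable object.
        # In the setting above this would be torch.cuda.memory._snapshot.
        self._take_snapshot = take_snapshot
        self.latest = None

    def observer(self, device, alloc, device_alloc, device_free):
        # Called on every OOM, caught or not; overwrite the previous snapshot.
        self.latest = self._take_snapshot()

    @contextmanager
    def save_if_uncaught(self, path, oom_error):
        # Write the latest snapshot to disk only if oom_error escapes the block.
        try:
            yield self
        except oom_error:
            if self.latest is not None:
                with open(path, "wb") as f:
                    pickle.dump(self.latest, f)
            raise  # assumption: the caller still wants to see the error
```

Assuming the observer hook from the posts above, you would create `recorder = LastOOMSnapshot(torch.cuda.memory._snapshot)`, pass `recorder.observer` to `torch._C._cuda_attach_out_of_memory_observer`, and run the application inside `with recorder.save_if_uncaught('oom_snapshot.pickle', torch.cuda.OutOfMemoryError):`. OOMs that are caught and retried inside the block only update `recorder.latest`; nothing is written unless the exception escapes the block.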