Event handler for CUDA OOM?

Is there a way to have an event handler or callback that triggers whenever (and wherever) a CUDA OOM error occurs?

For example, I’d like to trigger a printout of the name of the node on which the error occurs, because I know my run “fits” in VRAM, but sometimes my cluster will give me “bad nodes” and I’d like to know which nodes to exclude when I re-run the job.



I asked ChatGPT this question, and it suggested the following, haha:

import socket
import torch

def oom_callback():
    print(f"OOM error occurred on node {socket.gethostname()}")

# ChatGPT's hallucinated API -- this method does not exist:
torch.cuda.register_cudart_error_callback(oom_callback)


And when I told it there was no method `torch.cuda.register_cudart_error_callback`, ChatGPT would respond by saying that it was introduced in a later version of PyTorch than whichever version I had (e.g. it would say it was introduced in 1.10.0; when I replied that I had 1.13.0, it would say something like "oh sorry, I meant 1.14.0", lol).

Could you achieve this by catching the exception on the node and propagating it somehow?

import torch

try:
    a = torch.empty(2**40, device='cuda')
except torch.cuda.OutOfMemoryError:
    print("propagate this OOM exception")

Well, yes, but then wouldn’t you have to place copies of such code throughout your entire training system?
Because in my experience, once that exception happens, execution grinds to a halt.
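One way to avoid scattering try/except blocks everywhere is to wrap the top-level step function once. Here is a minimal sketch of that idea; the names (`report_oom_node`, `oom_types`) are my own, and `RuntimeError` is used as the default exception type only so the snippet is self-contained — in a real run you would pass `(torch.cuda.OutOfMemoryError,)` instead.

```python
import functools
import socket

def report_oom_node(fn, oom_types=(RuntimeError,)):
    """Wrap a callable so that an OOM-type exception prints the hostname
    of the node it occurred on before being re-raised.

    In an actual training script you'd pass
    oom_types=(torch.cuda.OutOfMemoryError,); RuntimeError is used here
    only to keep the sketch runnable without CUDA.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except oom_types as e:
            # Record which node failed, then let the exception propagate
            # so the job scheduler still sees the failure.
            print(f"OOM on node {socket.gethostname()}: {e}")
            raise
    return wrapper
```

You would then wrap only the single entry point, e.g. `train_step = report_oom_node(train_step, oom_types=(torch.cuda.OutOfMemoryError,))`, instead of copying the try/except throughout the codebase.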

This is how fairseq does it, btw: fairseq/trainer.py at 50a671f78d0c8de0392f924180db72ac9b41b801 · facebookresearch/fairseq · GitHub
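If I'm reading that trainer code right, the pattern is roughly: catch the exception around the train step, check the message text (older PyTorch raised OOM as a plain `RuntimeError`), log, free cached memory, and retry. A minimal sketch of that pattern — the function names here are illustrative, not fairseq's actual API:

```python
import socket

def is_oom_error(exc):
    # Older PyTorch raises CUDA OOM as a plain RuntimeError, so the
    # check matches on the message text rather than an exception class.
    return isinstance(exc, RuntimeError) and "out of memory" in str(exc)

def train_step_with_oom_handling(step_fn, batch, retries=1):
    """Run one training step; on a CUDA OOM, log the node name, free
    cached memory, and retry a limited number of times.

    `step_fn` and `retries` are assumptions for this sketch, not
    fairseq's real signatures.
    """
    for attempt in range(retries + 1):
        try:
            return step_fn(batch)
        except RuntimeError as e:
            if not is_oom_error(e) or attempt == retries:
                raise  # not an OOM, or out of retries: propagate
            print(f"OOM on node {socket.gethostname()}, "
                  f"retrying (attempt {attempt + 1})")
            # torch.cuda.empty_cache()  # release cached blocks first
```

The useful part for the original question is the single choke point: because every step goes through one function, the hostname printout happens exactly where the OOM occurs, without sprinkling handlers through the rest of the system.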


@marksaroufim thanks very much! I will give such a thing a try…