Is there a way to have an event handler or callback that triggers whenever (and wherever) a CUDA OOM error occurs?
For example, I’d like to trigger a printout of the current node name when the error occurs, because I know my run “fits” in VRAM, but sometimes my cluster gives me “bad nodes” and I’d like to know which nodes to exclude when I re-run the job.
Thanks.
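(For context, the closest thing I've found so far is just catching the exception myself. This is a minimal sketch, assuming PyTorch >= 1.13 where `torch.cuda.OutOfMemoryError` exists as a subclass of `RuntimeError`; the `report_oom` wrapper and `step_fn` names are made up for illustration, and the `oom_exc` parameter is only there so the snippet can be exercised without a GPU.)

```python
import socket

def report_oom(step_fn, oom_exc=RuntimeError):
    """Run step_fn(); if it raises an OOM-type error, print the
    hostname before re-raising so the bad node shows up in the logs.

    In real use you'd pass oom_exc=torch.cuda.OutOfMemoryError
    (available in PyTorch >= 1.13)."""
    try:
        return step_fn()
    except oom_exc:
        print(f"OOM error occurred on node {socket.gethostname()}")
        raise  # re-raise so the job still fails as before
```

Wrapping each training step like `report_oom(lambda: train_step(batch), torch.cuda.OutOfMemoryError)` gets the hostname into stdout, but it's not a global callback, so I'm still wondering if there's a hook that fires wherever the OOM happens.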
PS-
I asked ChatGPT this question, and it suggested the following, haha:
```python
import socket
import torch

def oom_callback():
    print(f"OOM error occurred on node {socket.gethostname()}")

torch.cuda.register_cudart_error_callback(oom_callback)
```
And when I told it there is no method `torch.cuda.register_cudart_error_callback`, ChatGPT would respond by saying it was introduced in a later version of PyTorch than whichever version I had (e.g. it would say it was introduced in 1.10.0; when I replied that I had 1.13.0, it would say something like “oh sorry, I meant 1.14.0”, lol).