Using `forward_pre_hook` to attribute CUDA OOMs to module execution context

While experimenting with lightweight profiling during PyTorch training, I noticed that CUDA OOMs and similar runtime errors often surface without clearly indicating which module was executing when the failure happened.

To improve this, I tried a small hook-based approach using:

  • forward_pre_hook on leaf modules

  • backward hooks to track execution phase (forward / backward)

The idea is to track the last entered module and phase, and when a failure occurs (e.g. CUDA OOM), report the execution context that was active at that time. For example:

OOM occurred during FORWARD of module: classifier
Peak GPU memory before crash: 21.4 GB

This doesn’t attempt exact tensor-level attribution, only the module execution context, but in practice it has been useful for narrowing down where memory spikes or failures originate, without running heavy profilers or re-executing the model.
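To make the idea concrete, here is a minimal sketch of the forward half of the mechanism. Helper names like `attach_context_hooks` and `EXECUTION_CONTEXT` are my own illustrative choices, not the utility's actual API, and the OOM check uses the common string-matching idiom on `RuntimeError` for portability across PyTorch versions:

```python
# Minimal sketch (illustrative names, not the actual utility's API):
# forward_pre_hooks on leaf modules record the last entered module,
# and a try/except around the step reports that context on a CUDA OOM.
import torch
import torch.nn as nn

EXECUTION_CONTEXT = {"current": None}  # last entered (phase, module name)

def attach_context_hooks(model: nn.Module) -> None:
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            def pre_hook(mod, args, _name=name):
                # Cheap bookkeeping only: no allocation, no CUDA sync.
                EXECUTION_CONTEXT["current"] = ("FORWARD", _name)
            module.register_forward_pre_hook(pre_hook)

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
attach_context_hooks(model)

try:
    out = model(torch.randn(4, 8))
except RuntimeError as e:
    # Newer PyTorch raises torch.cuda.OutOfMemoryError (a RuntimeError
    # subclass); matching on the message keeps this version-agnostic.
    if "out of memory" in str(e):
        phase, mod_name = EXECUTION_CONTEXT["current"]
        print(f"OOM occurred during {phase} of module: {mod_name}")
        print(f"Peak GPU memory before crash: "
              f"{torch.cuda.max_memory_allocated() / 1e9:.1f} GB")
    raise
```

After a successful CPU forward pass here, `EXECUTION_CONTEXT["current"]` holds the last leaf entered, e.g. `("FORWARD", "2")` for the final `Linear` in the `Sequential`.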

Would appreciate feedback on a few points:

  1. Is this a reasonable use of forward_pre_hook for failure attribution?

  2. Are there known edge cases where this reasoning breaks down (e.g. CUDA async errors, activation checkpointing, AMP, fused ops)?

  3. Are there existing PyTorch-recommended patterns for surfacing this kind of execution-context information on failures?

I have wrapped this approach into a small utility, but the core idea is just the hook mechanism described above.

  1. Yes, I believe your hook approach is valid, but I would also be curious to see what exactly you’ve implemented and how the error attribution is created.
  2. Depending on your actual implementation, you could break valid use cases. E.g. if you added synchronizing code to the hook, it would break CUDA Graphs usage and also slow down the code.
  3. Tensor subclasses often come in handy for these (debug) approaches, so you could also check if your current implementation would be compatible with these.

My approach is intentionally approximate: I register forward_pre_hooks on leaf modules to record the last entered forward context, and attach a tensor backward hook on the module output to record backward context.

The hook logic only updates a global string (EXECUTION_LAYER.current) and does not allocate tensors or synchronize CUDA.

If a CUDA OOM is raised, I report the last entered execution context (forward/backward + module name). This is meant as best-effort attribution rather than precise allocation blame, given CUDA’s asynchronous error reporting.
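A sketch of the backward-phase tracking described above (the class name mirrors the `EXECUTION_LAYER` global mentioned earlier, but the exact names and structure are mine, not the utility's): a `forward_hook` attaches a `Tensor.register_hook` callback to each leaf module's output, so computing the gradient of that output flips the shared context to the backward phase.

```python
# Sketch only: forward_pre_hook records forward context; a tensor hook on
# the module output records backward context when its gradient is computed.
import torch
import torch.nn as nn

class ExecutionLayer:
    current = "none"  # e.g. "FORWARD:classifier" or "BACKWARD:classifier"

def attach_phase_hooks(model: nn.Module) -> None:
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            def pre_hook(mod, args, _name=name):
                ExecutionLayer.current = f"FORWARD:{_name}"
            def fwd_hook(mod, args, output, _name=name):
                if isinstance(output, torch.Tensor) and output.requires_grad:
                    # Fires when the grad w.r.t. this output is computed.
                    output.register_hook(
                        lambda grad, _name=_name: setattr(
                            ExecutionLayer, "current", f"BACKWARD:{_name}"))
            module.register_forward_pre_hook(pre_hook)
            module.register_forward_hook(fwd_hook)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
attach_phase_hooks(model)
loss = model(torch.randn(2, 4)).sum()
loss.backward()
```

Since autograd walks the graph in reverse, the last tensor hook to fire is the one on the first module's output, so after `backward()` the context reads `"BACKWARD:0"` here. The hooks only assign a string, so they add no allocation or synchronization to either pass.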

The exact code is here: