While experimenting with lightweight profiling during PyTorch training, I noticed that CUDA OOMs and similar runtime errors often surface without clearly indicating which module was executing when the failure happened.
To improve this, I tried a small hook-based approach using:

- `forward_pre_hook` on leaf modules
- backward hooks to track the execution phase (forward / backward)
The idea is to track the last-entered module and the current phase, and when a failure occurs (e.g. a CUDA OOM), report the execution context that was active at that time. For example:

```
OOM occurred during FORWARD of module: classifier
Peak GPU memory before crash: 21.4 GB
```
This doesn’t attempt exact tensor-level attribution, only the module execution context, but in practice it has been useful for narrowing down where memory spikes or failures originate, without running heavy profilers or re-executing the model.
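For concreteness, here is a minimal sketch of the hook mechanism described above. The `ExecutionTracker` class name is my own for illustration; it registers a `forward_pre_hook` and a full backward pre-hook on each leaf module and records the last module entered plus the phase, then formats a report with `torch.cuda.max_memory_allocated`. It assumes a recent PyTorch version (where `register_full_backward_pre_hook` is available).

```python
import torch
import torch.nn as nn

class ExecutionTracker:
    """Records the last module entered and the current phase (forward/backward)."""

    def __init__(self, model: nn.Module):
        self.last_module = None
        self.phase = None
        for name, module in model.named_modules():
            # Hook only leaf modules (no children) for fine-grained attribution.
            if len(list(module.children())) == 0:
                module.register_forward_pre_hook(self._forward_hook(name))
                module.register_full_backward_pre_hook(self._backward_hook(name))

    def _forward_hook(self, name):
        def hook(module, inputs):
            self.last_module, self.phase = name, "FORWARD"
        return hook

    def _backward_hook(self, name):
        def hook(module, grad_output):
            self.last_module, self.phase = name, "BACKWARD"
        return hook

    def report(self) -> str:
        msg = f"Failure during {self.phase} of module: {self.last_module}"
        if torch.cuda.is_available():
            peak_gb = torch.cuda.max_memory_allocated() / 1e9
            msg += f"\nPeak GPU memory before crash: {peak_gb:.1f} GB"
        return msg

# Usage sketch: wrap the training step and report the context on failure.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
tracker = ExecutionTracker(model)
try:
    out = model(torch.randn(3, 8))
    out.sum().backward()
except RuntimeError:
    print(tracker.report())
    raise
```

Note that because CUDA errors are raised asynchronously, the module reported here is the one whose kernel launch surfaced the error, which is not guaranteed to be the one that allocated the memory.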
Would appreciate feedback on a few points:

- Is this a reasonable use of `forward_pre_hook` for failure attribution?
- Are there known edge cases where this reasoning breaks down (e.g. CUDA async errors, activation checkpointing, AMP, fused ops)?
- Are there existing PyTorch-recommended patterns for surfacing this kind of execution-context information on failures?
I have wrapped this approach into a small utility, but the core idea is just the hook mechanism described above.