I just open-sourced torchdiag, a small diagnostic toolkit for PyTorch models. It brings SRE-style observability to model training with five commands:
- torchdiag.summary(model) — parameter counts, memory footprint, device placement
- torchdiag.check_gradients(model) — flags vanishing, exploding, or disconnected gradients
- torchdiag.check_dead_neurons(model, x) — detects dead ReLU neurons per layer
- torchdiag.verify_step(model, opt, loss, x, y) — runs one training step and verifies everything works
- torchdiag.memory_report() — GPU/CPU/MPS memory snapshot
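The gradient and dead-neuron checks above come down to a few simple rules. Here is a minimal sketch of that kind of logic. It is not torchdiag's implementation. Plain Python lists stand in for tensors so the example runs without PyTorch. The thresholds and the names `classify_gradient`, `flag_gradients` and `dead_fraction` are hypothetical. In real PyTorch code, the gradients would come from each parameter's `.grad` after `loss.backward()`, and the activations would come from forward hooks on the ReLU layers.

```python
import math

# Hypothetical thresholds for illustration; not torchdiag's defaults.
VANISHING_THRESHOLD = 1e-7
EXPLODING_THRESHOLD = 1e3


def classify_gradient(grad):
    """Classify one parameter's gradient by its L2 norm.

    `grad` is a flat list of floats, or None when the parameter received
    no gradient (like a PyTorch parameter whose .grad is None after backward()).
    """
    if grad is None:
        return "disconnected"
    norm = math.sqrt(sum(g * g for g in grad))
    # NaN/Inf norms are treated as exploding in this sketch.
    if not math.isfinite(norm) or norm > EXPLODING_THRESHOLD:
        return "exploding"
    if norm < VANISHING_THRESHOLD:
        return "vanishing"
    return "ok"


def flag_gradients(named_grads):
    """Return {name: status} for every parameter whose gradient is not 'ok'."""
    report = {}
    for name, grad in named_grads.items():
        status = classify_gradient(grad)
        if status != "ok":
            report[name] = status
    return report


def dead_fraction(activations):
    """Fraction of units in one layer that are zero for every sample in the batch.

    `activations` is a list of rows (one per sample), each holding that
    layer's post-ReLU outputs.
    """
    num_units = len(activations[0])
    dead = sum(
        1 for j in range(num_units)
        if all(row[j] <= 0.0 for row in activations)
    )
    return dead / num_units
```

For example, `flag_gradients({"fc1.weight": [0.1, -0.2], "head.bias": None})` returns `{"head.bias": "disconnected"}`. A per-layer dead fraction close to 1.0 is the usual sign of a dying-ReLU problem.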
Install: pip install torchdiag
GitHub: AddyM/torchdiag
PyPI: torchdiag
Built this because I kept writing the same debugging snippets over and over: print the gradients, check for dead neurons, compare weights before and after a step. Figured others might find it useful too.
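The "compare weights before and after a step" snippet is the kind of check a one-step verification automates. Here is a minimal framework-free sketch, again using plain lists in place of tensors. `snapshot`, `unchanged_params` and `sgd_step` are hypothetical helpers, and `sgd_step` stands in for a real optimizer. This is not how torchdiag's `verify_step` is implemented. In PyTorch, the equivalent would be to clone each parameter with `p.detach().clone()` before `optimizer.step()` and compare the clones afterward.

```python
import copy


def snapshot(params):
    """Deep-copy a {name: list_of_floats} mapping so in-place updates don't alias it."""
    return copy.deepcopy(params)


def unchanged_params(before, after, atol=0.0):
    """Names of parameters whose values did not move by more than `atol`."""
    stale = []
    for name, old in before.items():
        new = after[name]
        if all(abs(a - b) <= atol for a, b in zip(old, new)):
            stale.append(name)
    return stale


def sgd_step(params, grads, lr=0.1):
    """Plain in-place SGD update; parameters with no gradient are skipped."""
    for name, values in params.items():
        grad = grads.get(name)
        if grad is None:
            continue
        for i, g in enumerate(grad):
            values[i] -= lr * g
```

With `params = {"w": [1.0, 2.0], "b": [0.5]}` and a step where `b` gets no gradient, `unchanged_params(before, params)` returns `["b"]`. A parameter that never moves usually means it is frozen, detached from the loss, or missing from the optimizer's parameter list.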
Feedback and contributions welcome.