Is there a library for automated regression testing of machine learning training code?

I want to test my deep learning training code automatically, to make sure that refactoring does not change the behavior of the update step of my training. I suspect I also need some kind of coverage metric for my computational graph.

Is there a library that does that for PyTorch or JAX? I do not want to test the trained model (there are libraries for that), but rather the training code itself.
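For concreteness, here is a minimal hand-rolled sketch of the kind of check meant here. It is an illustration only, not an existing library. To keep it self-contained it avoids PyTorch/JAX and uses plain Python: a 1-D linear regression with MSE loss and plain SGD. The names `update_step_v1`, `update_step_v2`, `make_batch` and `check_same_updates` are hypothetical. The idea is to run the old and the refactored update step from the same initial parameters on the same seeded batches, then check that the parameters agree within a tolerance after every step.

```python
import math
import random


def update_step_v1(w, b, xs, ys, lr):
    """Hypothetical 'before refactor' SGD step: 1-D linear model, MSE loss."""
    n = len(xs)
    grad_w = 0.0
    grad_b = 0.0
    for x, y in zip(xs, ys):
        err = (w * x + b) - y
        grad_w += 2.0 * err * x / n  # d/dw of mean((w*x + b - y)^2)
        grad_b += 2.0 * err / n      # d/db of mean((w*x + b - y)^2)
    return w - lr * grad_w, b - lr * grad_b


def update_step_v2(params, batch, lr):
    """Hypothetical 'after refactor' version with a different signature."""
    w, b = params
    xs, ys = batch
    errs = [(w * x + b) - y for x, y in zip(xs, ys)]
    n = len(errs)
    grad_w = sum(2.0 * e * x for e, x in zip(errs, xs)) / n
    grad_b = sum(2.0 * e for e in errs) / n
    return (w - lr * grad_w, b - lr * grad_b)


def make_batch(seed, n=32):
    """Deterministic synthetic batch: y = 3x - 0.5 plus small noise."""
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    ys = [3.0 * x - 0.5 + rng.gauss(0.0, 0.1) for x in xs]
    return xs, ys


def check_same_updates(new_step, steps=10, lr=0.1, seed=0,
                       rel_tol=1e-9, abs_tol=1e-12):
    """Run the reference step and `new_step` side by side on identical
    seeded batches; return True iff the parameters agree within tolerance
    after every step (tolerances allow for float reordering)."""
    old = (0.0, 0.0)
    new = (0.0, 0.0)
    for step in range(steps):
        xs, ys = make_batch(seed + step)
        old = update_step_v1(old[0], old[1], xs, ys, lr)
        new = new_step(new, (xs, ys), lr)
        for a, b in zip(old, new):
            if not math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol):
                return False
    return True


if __name__ == "__main__":
    print(check_same_updates(update_step_v2))
```

With PyTorch or JAX the same pattern would use seeded tensors and `torch.testing.assert_close` or `numpy.testing.assert_allclose`. Since the old implementation usually does not survive the refactor, the reference parameters would in practice be saved to a file beforehand and compared against afterwards. This sketch says nothing about graph coverage.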