How to record each layer's gradients in each epoch?

You can use register_backward_hook (or register_full_backward_hook, which replaces it in newer PyTorch versions) and gather the gradients in a shared dict.
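A minimal sketch of that idea, assuming a toy nn.Sequential model, random data and plain SGD purely for illustration. The names grads, param_grads, make_hook and current_epoch are hypothetical. A full backward hook gives the gradient of the loss with respect to each layer's output. If you instead want the gradients of each layer's weights and biases, you can read p.grad after loss.backward(), as the second dict does. Both store one norm per batch, keyed by epoch and then by layer name.

```python
from collections import defaultdict

import torch
import torch.nn as nn

# Shared dicts (hypothetical names): store[epoch][layer_name] -> list of per-batch norms
grads = defaultdict(lambda: defaultdict(list))        # filled by the backward hooks
param_grads = defaultdict(lambda: defaultdict(list))  # filled after loss.backward()
current_epoch = 0  # updated by the training loop so the hooks know the epoch


def make_hook(name):
    # register_full_backward_hook calls hook(module, grad_input, grad_output)
    def hook(module, grad_input, grad_output):
        # grad_output[0]: gradient of the loss w.r.t. this layer's output
        grads[current_epoch][name].append(grad_output[0].detach().norm().item())
    return hook


# Toy model and data, assumed only for this example
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
handles = [
    module.register_full_backward_hook(make_hook(name))
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear)
]

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
torch.manual_seed(0)
data = [(torch.randn(16, 4), torch.randn(16, 1)) for _ in range(3)]

num_epochs = 2
for epoch in range(num_epochs):
    current_epoch = epoch
    for x, y in data:
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()  # the hooks fire here
        # Parameter gradients are populated once backward() has finished
        for pname, p in model.named_parameters():
            param_grads[epoch][pname].append(p.grad.norm().item())
        optimizer.step()

# Detach the hooks once you no longer need them
for h in handles:
    h.remove()

# Per-epoch summary: mean gradient norm for each layer
for epoch, layers in grads.items():
    for name, norms in layers.items():
        print(f"epoch {epoch} layer {name}: mean grad-output norm {sum(norms) / len(norms):.4f}")
```

Storing norms rather than full tensors keeps memory bounded across many epochs. If you need the complete gradients, append grad_output[0].detach().clone() (or p.grad.detach().clone()) instead.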