If I understand correctly, you are computing the trace of the Hessian (a.k.a. the Laplacian)? There is a function for computing the Hessian.
Since you only need the diagonal elements, this computes many entries you don't need, but I still think there is a chance it is faster.
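A minimal sketch of that idea, assuming you are using PyTorch and `torch.autograd.functional.hessian`. The function `f` below is a made-up example, not your model: it builds the full n x n Hessian and then keeps only its trace.

```python
import torch

# Hypothetical scalar function f: R^n -> R, used only for illustration.
def f(x):
    return (x ** 3).sum() + x[0] * x[1]

x = torch.tensor([1.0, 2.0, 3.0])

# Full n x n Hessian; the off-diagonal entries are computed but never used.
H = torch.autograd.functional.hessian(f, x)

# Laplacian = trace of the Hessian (here analytically 6 * (1 + 2 + 3) = 36).
laplacian = torch.trace(H)
```

The cost grows with n^2 entries even though only n of them are needed, so whether this beats a hand-written loop depends on your input size.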
Efficient implementations of that idea (using vmap to handle the vectorization) have also been discussed here.
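For reference, a rough sketch of the vmap variant, assuming PyTorch 2.0+ where `torch.func.hessian` and `torch.func.vmap` are available. `f` is again a hypothetical example; `vmap` computes the Hessian for a whole batch of points without a Python loop.

```python
import torch
from torch.func import hessian, vmap

# Hypothetical scalar function f: R^n -> R, used only for illustration.
def f(x):
    return (x ** 3).sum() + x[0] * x[1]

# A batch of two input points, shape (batch, n).
xs = torch.tensor([[1.0, 2.0, 3.0],
                   [0.0, -1.0, 0.5]])

# hessian(f) maps (n,) -> (n, n); vmap lifts it to (batch, n) -> (batch, n, n).
H = vmap(hessian(f))(xs)

# Sum the diagonal of each Hessian to get one Laplacian per point.
laplacians = H.diagonal(dim1=-2, dim2=-1).sum(-1)
```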
If this is still too slow (and this might be a long shot), there are methods for estimating the trace of the Hessian, but they are more involved.
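One such method is Hutchinson's estimator, tr(H) ≈ mean of vᵀHv over random probe vectors v with independent ±1 entries. Below is a sketch assuming PyTorch 2.0+ (`torch.func.grad`, `jvp`, `vmap`); `f`, `hvp` and `hutchinson_trace` are hypothetical names for illustration. It only needs Hessian-vector products, never the full Hessian.

```python
import torch
from torch.func import grad, jvp, vmap

# Hypothetical scalar function f: R^n -> R, used only for illustration.
def f(x):
    return (x ** 3).sum() + x[0] * x[1]

def hvp(f, x, v):
    # Forward-over-reverse Hessian-vector product: returns H(x) @ v.
    return jvp(grad(f), (x,), (v,))[1]

def hutchinson_trace(f, x, num_samples=1000, generator=None):
    # Rademacher probes: each entry is +1 or -1 with equal probability.
    v = torch.randint(0, 2, (num_samples, x.shape[0]), generator=generator)
    v = v.to(x.dtype) * 2 - 1
    # One Hessian-vector product per probe, vectorized with vmap.
    Hv = vmap(lambda vi: hvp(f, x, vi))(v)
    # Average v^T H v over the probes.
    return (v * Hv).sum(dim=1).mean()

x = torch.tensor([1.0, 2.0, 3.0])
est = hutchinson_trace(f, x, num_samples=1000,
                       generator=torch.Generator().manual_seed(0))
```

With ±1 probes the diagonal contributions are exact for every sample, so all the variance comes from the off-diagonal entries; the estimate is unbiased and its error shrinks roughly like 1/sqrt(num_samples).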
Hope that helps!