Hello!

I am interested in the trace of the Jacobian of the score function. Is there an efficient way to compute it?

```python
import torch
from torch.autograd import grad

x = ...  # (B, N, D), with requires_grad=True
log_p = dist.log_prob(x)  # (B, N)
score_fn = grad(log_p.sum(), x, create_graph=True)[0]  # (B, N, D)
```

I can get it with for loops, but it’s quite slow.

Hi @sherlock.h,

You could have a look at defining a functional version of your computation and using the `torch.func` namespace to compute the Jacobian efficiently. The docs are here: torch.func.jacrev — PyTorch 2.4 documentation
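For reference, a minimal sketch of that approach, using a standard normal log-density as an illustrative stand-in for `dist.log_prob` (the function, shapes, and names here are assumptions, not from the thread). For this density the score is `-x` and its Jacobian is `-I`, so the trace should come out to exactly `-D`:

```python
import torch
from torch.func import jacrev, vmap

# Hypothetical per-point log-density, x of shape (D,):
# a standard normal, whose score is -x and score-Jacobian is -I.
def log_prob(x):
    return -0.5 * (x ** 2).sum()

score = jacrev(log_prob)  # maps (D,) -> (D,)

def trace_fn(x):
    # Full (D, D) Jacobian of the score, then take its trace.
    return torch.diagonal(jacrev(score)(x)).sum()

B, N, D = 2, 3, 4
x = torch.randn(B, N, D)
tr = vmap(vmap(trace_fn))(x)  # (B, N); equals -D for this density
```

Note this materializes the full `(D, D)` Jacobian per point, so it is exact but scales with `D`; the stochastic estimator below avoids that.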

Thanks for the suggestion!

I ended up going with Hutchinson’s trace estimator (an expectation over Rademacher probe vectors):

```python
v = torch.randn(M, B, N, D).sign()  # M Rademacher (±1) probe vectors
tr = (
    torch.autograd.grad(
        outputs=score_fn,
        inputs=x,
        grad_outputs=v,
        is_grads_batched=True,
    )[0]          # vᵀJ (vector-Jacobian product), batched over M: (M, B, N, D)
    .mul(v)       # v ⊙ (vᵀJ)
    .sum(-1)      # vᵀ J v: (M, B, N)
    .mean(-3)     # average over the M probes: (B, N)
)
```
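As a sanity check (not from the thread), the estimator can be verified on a case with a known answer — a standard normal, where the score is `-x`, its Jacobian is `-I`, and the exact trace is `-D`. With Rademacher probes the estimate is even exact per sample here, since `vᵀ(-I)v = -Σ vᵢ² = -D`:

```python
import torch

torch.manual_seed(0)
M, B, N, D = 8, 2, 3, 4

x = torch.randn(B, N, D, requires_grad=True)
log_p = -0.5 * (x ** 2).sum(-1)  # (B, N), standard normal up to a constant
score = torch.autograd.grad(log_p.sum(), x, create_graph=True)[0]  # = -x

v = torch.randint(0, 2, (M, B, N, D)).float() * 2 - 1  # Rademacher ±1 probes
tr = (
    torch.autograd.grad(score, x, grad_outputs=v, is_grads_batched=True)[0]
    .mul(v)
    .sum(-1)
    .mean(0)  # (B, N); should equal -D = -4 exactly in this case
)
```

For a general density the per-sample estimate is noisy and `M` controls the variance, so in practice `M` trades compute for accuracy.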
```