Computing the trace of the Jacobian of the score function

Hello!

I am interested in the trace of the Jacobian of the score function. Is there an efficient way to compute it?

import torch

x = ...  # (B, N, D), must have requires_grad=True
log_p = dist.log_prob(x)  # (B, N)
# create_graph=True so we can differentiate through the score afterwards
score_fn = torch.autograd.grad(log_p.sum(), x, create_graph=True)[0]  # (B, N, D)

I can get it with for loops, but it’s quite slow.
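
For reference, the loop version looks roughly like this (a sketch, with B, N, D as above; it assumes each log_p[b, n] depends only on x[b, n], so the cross terms in the gradient vanish):

tr = x.new_zeros(B, N)
for d in range(D):
    # row d of the score's Jacobian, one backward pass per input dimension
    g = torch.autograd.grad(score_fn[..., d].sum(), x, retain_graph=True)[0]  # (B, N, D)
    tr += g[..., d]  # diagonal entry (d, d); accumulating over d gives the trace

That is D backward passes, which is what I would like to avoid.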

Hi @sherlock.h,

You could define your computation as a function and use the torch.func namespace to compute the Jacobian of the score efficiently. The docs are here: torch.func.jacrev — PyTorch 2.4 documentation
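
For example, something along these lines (a sketch, assuming dist.log_prob maps a single (D,) point to a scalar log-density; the nested vmap handles the (B, N) batch dimensions):

import torch
from torch.func import grad, jacrev, vmap

score = grad(dist.log_prob)           # (D,) -> (D,)
jac_score = jacrev(score)             # (D,) -> (D, D) Jacobian of the score

def trace_fn(p):
    return torch.trace(jac_score(p))  # exact trace for one point

tr = vmap(vmap(trace_fn))(x)          # (B, N, D) -> (B, N)

Note this materializes the full D x D Jacobian per point, so it is exact but only practical for moderate D.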

Thanks for the suggestion!

I ended up going with Hutchinson's trace estimator, i.e. the expectation of $v^\top J v$ over Rademacher probes $v$, where $J$ is the Jacobian of the score w.r.t. $x$. It equals the trace because $\mathbb{E}[v v^\top] = I$:

$$\operatorname{tr}(J) = \mathbb{E}_{v \sim \text{Rademacher}}\left[\, v^\top J v \,\right]$$

v = torch.randn(M, B, N, D).sign()  # M Rademacher probes, entries in {-1, +1}
tr = (
    torch.autograd.grad(
        outputs=score_fn,          # (B, N, D), built with create_graph=True
        inputs=x,
        grad_outputs=v,            # each probe yields one VJP: v^T J
        is_grads_batched=True,     # vmaps over the leading M dimension
    )[0]                           # (M, B, N, D)
    .mul(v)
    .sum(-1)                       # v^T J v  ->  (M, B, N)
).mean(-3)                         # average over the M probes -> (B, N)
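
A note on the probes: .sign() on a standard normal gives ±1 entries with equal probability, i.e. Rademacher samples (the chance of an exact 0 is zero). You can also draw them directly, e.g.:

v = torch.randint(0, 2, (M, B, N, D), dtype=x.dtype, device=x.device) * 2 - 1  # entries in {-1, +1}

The estimate is unbiased for any zero-mean, unit-covariance v, and its variance falls as 1/M; Rademacher probes are the standard low-variance choice.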