Hello!
I am interested in the trace of the Jacobian of the score function. Is there an efficient way to compute it?
x = ...  # (B, N, D), with x.requires_grad_(True)
log_p = dist.log_prob(x)  # (B, N)
# score of x under dist; create_graph=True so it can be differentiated again
score_fn = torch.autograd.grad(log_p.sum(), x, create_graph=True)[0]  # (B, N, D)
I can get it with for loops, but it’s quite slow.
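For context, the loop I have in mind is roughly the following sketch: one backward pass per coordinate d, accumulating the diagonal entries of the Jacobian of the score.

tr = torch.zeros_like(log_p)  # (B, N)
for d in range(D):
    # d-th row of d(score)/dx for every sample; keep only its d-th (diagonal) entry
    row = torch.autograd.grad(score_fn[..., d].sum(), x, retain_graph=True)[0]  # (B, N, D)
    tr += row[..., d]

That is D backward passes, so it scales badly as D grows.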
Hi @sherlock.h,
You could have a look at writing your computation as a function and using the torch.func
namespace to efficiently compute the Jacobian of that function. The docs are here: torch.func.jacrev — PyTorch 2.4 documentation
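In case it helps, here is a rough sketch of what that could look like for your shapes. It assumes dist.log_prob maps a single (D,) point to a scalar and works under vmap, so adapt as needed:

import torch
from torch.func import grad, jacrev, vmap

def log_prob_single(xi):              # xi: (D,) -> scalar
    return dist.log_prob(xi)          # assumption: log_prob accepts a single point

score_single = grad(log_prob_single)  # (D,) -> (D,)
jac_single = jacrev(score_single)     # (D,) -> (D, D), i.e. the Hessian of log_prob

def trace_single(xi):
    return torch.diagonal(jac_single(xi)).sum()

# map over the N and B dimensions: (B, N, D) -> (B, N)
tr = vmap(vmap(trace_single))(x)

torch.func.hessian(log_prob_single) would give you the same (D, D) matrix in one call. Note that this materializes the full per-point Jacobian, which is fine for moderate D.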
Thanks for the suggestion!
I ended up going with Hutchinson’s trace estimator (taking the expectation over Rademacher-distributed probe vectors v):
M = ...  # number of probe vectors
v = torch.randn(M, B, N, D).sign()  # Rademacher probes
tr = (
    torch.autograd.grad(
        outputs=score_fn,
        inputs=x,
        grad_outputs=v,          # batched v^T J, one VJP per probe
        is_grads_batched=True,   # leading M dim of v is the batch dim
    )[0]                         # (M, B, N, D)
    .mul(v)
    .sum(-1)                     # v^T J v per probe: (M, B, N)
).mean(-3)                       # average over the M probes -> (B, N)
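This only needs M batched vector-Jacobian products instead of D backward passes, and the estimator’s variance shrinks as M grows. For anyone who finds this later, a minimal self-contained sanity check (using a standard MultivariateNormal, chosen only because the Jacobian of the score is exactly -I there, so the trace should come out at -D):

import torch

B, N, D, M = 4, 8, 16, 64
dist = torch.distributions.MultivariateNormal(torch.zeros(D), torch.eye(D))

x = torch.randn(B, N, D, requires_grad=True)
log_p = dist.log_prob(x)                                               # (B, N)
score_fn = torch.autograd.grad(log_p.sum(), x, create_graph=True)[0]  # (B, N, D)

v = torch.randn(M, B, N, D).sign()
tr = (
    torch.autograd.grad(score_fn, x, grad_outputs=v, is_grads_batched=True)[0]
    .mul(v)
    .sum(-1)
).mean(-3)

print(tr.mean().item())  # -16.0 here, since d(score)/dx = -I for a standard Gaussian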