Specifically, I found that using jacrev is much slower than calling a hand-written ("oracle") Jacobian function:

```python
from torch.func import vmap, jacrev
import torch
import time

a = torch.rand(10000, 10000)

def f(x):
    return (x ** 2).sum(-1)

def df(x):
    # hand-written ("oracle") derivative of f
    return 2 * x

t0 = time.time()
b = df(a)
t1 = time.time()
c = vmap(jacrev(f))(a)  # per-row Jacobian computed by autograd
t2 = time.time()
assert torch.allclose(b, c)
print(t1 - t0, t2 - t1)
```

Result: 0.10568618774414062 0.9206998348236084

Given that the exact Jacobian of each layer is readily available in neural networks, why is jacrev so much slower? Am I doing something wrong?

Of course, I could rewrite each layer of the neural network to return its value and its Jacobian at the same time, but computing the Hessian that way is too cumbersome. It would be great if jacrev could be faster.

In the first case, you perform a single direct computation: no autograd, no
computation graph, and hence none of the associated overhead. Note that in
the first case you never compute x ** 2 (or f(x) at all), so if you need that
result, you would have to compute it separately.
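If you do need both f(x) and its derivative, autograd can hand you both from one forward/backward pass. A minimal sketch using torch.func.grad_and_value on the same f (the small tensor size is just for illustration):

```python
import torch
from torch.func import vmap, grad_and_value

def f(x):
    # scalar-valued for each row: sum of squares
    return (x ** 2).sum(-1)

a = torch.rand(4, 3)
# grad_and_value returns (gradient, value); vmap maps it over the rows of a
g, v = vmap(grad_and_value(f))(a)
print(g.shape, v.shape)  # torch.Size([4, 3]) torch.Size([4])
```

You still pay the autograd overhead, but you no longer need a separate evaluation of f.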

In the second case, jacrev() performs the forward pass that computes x ** 2
(and its sum) and then the backward pass that computes the gradient
of the forward pass (which happens to be 2 * x). It does this with all the
overhead of the autograd machinery.
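To make the "forward pass, then backward pass" structure concrete, here is a hand-rolled version using torch.func.vjp. This is a sketch of the idea for a scalar-valued f, not jacrev()'s actual implementation; the helper name grad_via_vjp is hypothetical:

```python
import torch
from torch.func import vjp, vmap

def f(x):
    return (x ** 2).sum(-1)

def grad_via_vjp(x):
    # forward pass: evaluate f and keep what the backward pass needs
    value, vjp_fn = vjp(f, x)
    # backward pass: pull back a cotangent of 1 (f is scalar-valued per row)
    (g,) = vjp_fn(torch.ones_like(value))
    return g

a = torch.rand(5, 3)
g = vmap(grad_via_vjp)(a)
```

Both steps run through autograd, which is where the extra time goes compared with simply evaluating 2 * x.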

If your real-world use case only has one “layer,” then using autograd might
well be overkill – the real benefit of autograd is that it knows how to chain
multiple layers together.
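As an illustration of that chaining, consider a hypothetical two-"layer" function, sin(x ** 2) summed per row. grad() applies the chain rule for you, and the result matches the derivative worked out by hand:

```python
import torch
from torch.func import vmap, grad

def f(x):
    # two "layers": square, then sin, then sum to a scalar
    return torch.sin(x ** 2).sum(-1)

def df_by_hand(x):
    # chain rule done manually: d/dx sin(x^2) = cos(x^2) * 2x
    return torch.cos(x ** 2) * 2 * x

a = torch.rand(5, 3)
auto = vmap(grad(f))(a)
manual = df_by_hand(a)
```

With one layer the hand-written version is trivial; with many layers (or second derivatives) keeping df_by_hand correct is exactly the bookkeeping autograd saves you.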

As an aside, the Jacobian in your example is really just the simpler case of
a gradient: after applying vmap(), you are computing x ** 2 on a single
row of a at a time, after which .sum(-1) leaves you with a single scalar
value, and the "Jacobian" of a scalar-valued function is its gradient.

You can verify this by replacing vmap(jacrev(f))(a) in your example
with vmap(grad(f))(a) (importing grad from torch.func).
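A small check of that equivalence, using the f from the question on a smaller input:

```python
import torch
from torch.func import vmap, jacrev, grad

def f(x):
    return (x ** 2).sum(-1)

a = torch.rand(5, 3)
j = vmap(jacrev(f))(a)  # per-row "Jacobian" of a scalar-valued function
g = vmap(grad(f))(a)    # per-row gradient
```

Both results have the same shape as a and agree elementwise, since the Jacobian of a scalar output with respect to a length-3 row is just a length-3 vector.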

Here's an example script similar to yours that shows the overhead of the
autograd forward / backward pass when computing just the gradient,
rather than the (scalar-function) Jacobian:

```python
import torch
print(torch.__version__)
import time

_ = torch.manual_seed(2023)
a = torch.randn(100000000)

def f(x):
    return x.sin()

def df(x):
    # hand-written derivative of sin
    return x.cos()

t0 = time.time()
b = df(a)
t1 = time.time()
c = torch.func.vmap(torch.func.grad(f))(a)  # per-element gradient via autograd
t2 = time.time()
print('torch.allclose(b, c) =', torch.allclose(b, c))
print(t1 - t0, t2 - t1)
```

And here is its output:

2.0.0
torch.allclose(b, c) = True
0.07053017616271973 0.4128530025482178