I want to compute Jacobian matrices using PyTorch’s `autograd`. In reverse mode, autograd natively computes vector-Jacobian products (vᵀJ), so I’d simply like to pass an identity matrix as v to obtain the full Jacobian (i.e., vᵀJ = IJ = J).
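
A minimal illustration of the premise, assuming a toy linear map whose Jacobian is known (`A` and `x` are illustrative names): each row e_i of the identity, passed as `grad_outputs`, returns row i of J, so feeding every row recovers the whole matrix. This version still loops over the rows.

```python
import torch

# Toy linear map f(x) = A @ x, whose Jacobian is exactly A.
torch.manual_seed(0)
A = torch.randn(3, 4)
x = torch.randn(4, requires_grad=True)
y = A @ x  # shape (3,)

# Reverse mode: grad_outputs=v gives the vector-Jacobian product v^T J.
# With v = e_i (row i of the identity) that is row i of J.
I = torch.eye(3)
rows = [torch.autograd.grad(y, x, grad_outputs=I[i], retain_graph=True)[0]
        for i in range(3)]
J = torch.stack(rows)  # shape (3, 4), equal to A
```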

One wrinkle: I’d like to implement not only the standard reverse-mode AD computation of the Jacobian but also a forward-mode version (which should be faster for most of my applications), using a trick due to Jamie Townsend (reference [1] in the second docstring below).
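
The trick itself (from the blog post cited in the `fwd_jacobian` docstring) obtains a Jacobian-vector product J u from two reverse-mode calls: first g(v) = Jᵀv for a dummy v that requires grad, then the VJP of g with u, which equals J u because g is linear in v. A minimal sketch, assuming `fxn` maps a 1-D tensor to a 1-D tensor; `jvp_via_double_vjp` is a hypothetical helper name:

```python
import torch

def jvp_via_double_vjp(fxn, x, u):
    """Jacobian-vector product J @ u using two reverse-mode calls."""
    x = x.detach().requires_grad_(True)
    y = fxn(x)
    # Dummy cotangent v: v^T J is linear in v, so its value does not matter,
    # it only has to require grad.
    v = torch.zeros_like(y, requires_grad=True)
    # First reverse pass: g(v) = J^T v, recorded so it can be differentiated.
    g = torch.autograd.grad(y, x, grad_outputs=v, create_graph=True)[0]
    # Second reverse pass: d/dv (u . J^T v) = J u.
    return torch.autograd.grad(g, v, grad_outputs=u)[0]

# Usage on a small nonlinear function R^3 -> R^5
torch.manual_seed(0)
W = torch.randn(5, 3)
f = lambda z: torch.tanh(W @ z)
x0 = torch.randn(3)
u = torch.randn(3)
Ju = jvp_via_double_vjp(f, x0, u)  # shape (5,)
```

`torch.autograd.functional.jvp(f, x0, u)` should return the same vector as its second output, since it also relies on this double-backward approach.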

I’ve actually gotten the former (reverse-mode) working, using a few hints found online and some fairly straightforward work:

``````python
import torch

def rev_jacobian(fxn, x, n_outputs, retain_graph):
    """
    the basic idea is to create N copies of the input
    and then ask for each of the N dimensions of the
    output... this allows us to compute J with pytorch's
    vector-Jacobian engine

    Assumes `x` requires grad and that `fxn` treats the
    leading (batch) dimension row by row, independently.
    """

    # expand the input, one copy per output dimension
    n_outputs = int(n_outputs)
    repeat_arg = (n_outputs,) + (1,) * len(x.size())
    xr = x.repeat(*repeat_arg)

    # both y and I are shape (n_outputs, n_outputs)
    #  checking y shape lets us report something meaningful
    y = fxn(xr).view(n_outputs, -1)

    if y.size(1) != n_outputs:
        raise ValueError('Function `fxn` does not give output '
                         'compatible with `n_outputs`=%d, size '
                         'of fxn(x) : %s'
                         '' % (n_outputs, y.size(1)))
    I = torch.eye(n_outputs, dtype=y.dtype, device=xr.device)

    # row i of I selects output i of copy i, so row i of the
    # result is d y[i, i] / d xr[i], i.e. row i of the Jacobian
    J = torch.autograd.grad(y, xr,
                            grad_outputs=I,
                            retain_graph=retain_graph,
                            create_graph=True,  # for higher order derivatives
                            )

    return J[0]
``````

However, for the forward-mode version I can’t seem to get `autograd` to behave. I cannot get it to accept the identity matrix as the `grad_outputs` argument without it returning a sum of the columns of the Jacobian I’d like, rather than the Jacobian itself. But given the success of the reverse-mode implementation, it seems like it should be easy!!
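
One plausible explanation for the "sum of the columns" symptom (a guess, since the attempted call isn't shown): if the VJP is repeated but every copy depends on the same dummy vector `v`, autograd accumulates all per-copy gradients into that one `v`, giving J e_1 + ... + J e_n, i.e. J applied to a vector of ones. A small sketch on a toy linear map (names are illustrative):

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 4)                    # Jacobian of the toy map x -> A @ x
x = torch.randn(4, requires_grad=True)
y = A @ x

# One shared dummy cotangent v, as in the forward-mode trick.
v = torch.zeros(3, requires_grad=True)
vjp = torch.autograd.grad(y, x, grad_outputs=v, create_graph=True)[0]  # A^T v, shape (4,)

# Repeat vjp 4 times and pass the identity as grad_outputs. All copies depend
# on the same v, so autograd sums the 4 contributions J e_0 + ... + J e_3.
vjp_rep = vjp.expand(4, 4)
summed = torch.autograd.grad(vjp_rep, v, grad_outputs=torch.eye(4))[0]  # shape (3,)
```

Getting separate columns requires a separate `v` per copy, which is what repeating both `x` and `v` achieves.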

Here’s the current code, with a commented out version of what I’d like to do:

``````python
import torch

def fwd_jacobian(fxn, x, n_outputs, retain_graph):
    """
    This implementation is very similar to the above, but with
    one twist. To implement a forward-mode AD with rev-mode
    calls, we first compute the rev-mode VJP for one vector (v)
    then we call d/dv(VJP) `n_inputs` times, one per basis vector,
    to obtain the Jacobian.

    This should be faster if `n_outputs` > "n_inputs"

    References
    ----------
    .[1] https://j-towns.github.io/2017/06/12/A-new-trick.html
    (Thanks to Jamie Townsend for this awesome trick!)
    """

    # leaf copy of x that autograd can differentiate w.r.t.
    xd = x.detach().requires_grad_(True)
    n_inputs = int(xd.size(0))

    # first, compute *any* VJP
    y = fxn(xd.view(1, n_inputs)).view(-1)

    if y.size(0) != n_outputs:
        raise ValueError('Function `fxn` does not give output '
                         'compatible with `n_outputs`=%d, size '
                         'of fxn(x) : %s'
                         '' % (n_outputs, y.size(0)))

    # dummy vector v: its value is irrelevant, it only has to require grad
    v = torch.zeros_like(y, requires_grad=True)
    vjp = torch.autograd.grad(y, xd, grad_outputs=v,
                              create_graph=True,
                              retain_graph=retain_graph)[0]
    assert vjp.shape == (n_inputs,)

    # TODO somehow the repeat trick does not work anymore
    #      now that we have to take derivatives wrt v
    #      so loop over basis vectors and compose jacobian col by col

    I = torch.eye(n_inputs, dtype=vjp.dtype, device=x.device)
    J = []
    for i in range(n_inputs):
        # d/dv (e_i . J^T v) = J e_i, i.e. column i of the Jacobian;
        # the graph of vjp must survive until the last iteration
        Ji = torch.autograd.grad(vjp, v,
                                 grad_outputs=I[i],
                                 retain_graph=retain_graph or i < n_inputs - 1,
                                 create_graph=True,  # for higher order derivatives
                                 )
        J.append(Ji[0])

    # stack the columns into shape (n_outputs, n_inputs)
    return torch.stack(J, dim=1)
``````

Looping over each column of the final Jacobian is slow, of course, and wasteful: I’m computing the Jacobian N times and throwing most of it away each time. Since this is the bottleneck in my code, it would be sweet to do it all in one go.
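
Later PyTorch versions (1.11 and newer, if I recall correctly) added an experimental `is_grads_batched=True` flag to `torch.autograd.grad`, which vmaps over the first dimension of `grad_outputs`. That lets the second pass of the trick take the whole identity matrix at once instead of looping. A sketch, assuming `fxn` maps a 1-D input to a 1-D output; `fwd_jacobian_batched` is a hypothetical name. It still performs one VJP per basis vector (only batched), and ops without a batching rule may fall back to a slower loop.

```python
import torch

def fwd_jacobian_batched(fxn, x):
    """Double-VJP forward-mode trick with the second passes vectorized."""
    x = x.detach().requires_grad_(True)
    y = fxn(x)                                        # shape (n_outputs,)
    v = torch.zeros_like(y, requires_grad=True)       # dummy cotangent
    vjp = torch.autograd.grad(y, x, grad_outputs=v, create_graph=True)[0]
    I = torch.eye(x.numel(), dtype=x.dtype, device=x.device)
    # is_grads_batched=True treats dim 0 of grad_outputs as a batch of vectors,
    # so row i of the result is d(e_i . J^T v)/dv = J e_i, i.e. column i of J.
    cols = torch.autograd.grad(vjp, v, grad_outputs=I, is_grads_batched=True)[0]
    return cols.T                                     # shape (n_outputs, n_inputs)

torch.manual_seed(0)
W = torch.randn(6, 3)
f = lambda z: torch.tanh(W @ z)
x0 = torch.randn(3)
J = fwd_jacobian_batched(f, x0)  # shape (6, 3)
```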

Hoping someone with more knowledge of the guts of autograd can advise on the proper use here. Thanks in advance!!

PS. I won’t post one just yet for the sake of brevity, but if anyone is interested I can post a test case to work from as well.

Hi,

The base implementations I use here are from Adam’s gists: here is the one with the full Jacobian and Hessian, and here is the one for the forward-mode trick.

Thanks for those, but they don’t quite answer the question :). My question is: can I do this without looping over the columns of the Jacobian (which Adam does)?

I’m afraid you cannot. This is a limitation of automatic differentiation… You can only efficiently do vJ or Jv products.

Right. I suppose I was operating under the assumption pytorch’s implementation was flexible enough to accept an identity matrix as “v”. That may be wrong.
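
For reference, later PyTorch releases expose this directly: `torch.autograd.functional.jacobian` with `vectorize=True` batches the identity through a single vmapped call, and `strategy="forward-mode"` (which requires `vectorize=True`) does the same with forward-mode AD. A sketch on a hypothetical two-layer function, assuming a PyTorch version that has the `strategy` argument; the work is still one VJP or JVP per basis vector, just vectorized.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
W1 = torch.randn(8, 4)
W2 = torch.randn(16, 8)

def fxn(x):
    # Small two-layer function: 4 inputs -> 16 outputs.
    return W2 @ torch.tanh(W1 @ x)

x0 = torch.randn(4)
# Reverse mode: the 16 rows of the identity go through one vmapped backward.
J_rev = jacobian(fxn, x0, vectorize=True, strategy="reverse-mode")
# Forward mode: the 4 columns of the identity go through vmapped forward AD.
J_fwd = jacobian(fxn, x0, vectorize=True, strategy="forward-mode")
```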

So if you compare these two implementations, the first gives significantly faster run times (in my hands) than the second, by a factor of 10-100x depending on the case. Perhaps this is just due to some skipped CUDA overhead? I am trying to understand this behavior and see if I can exploit the same kind of speedup using the “forward mode trick”.

``````python
import torch

def _rev_jacobian(fxn, x, n_outputs, retain_graph=True):
    """
    the basic idea is to create N copies of the input
    and then ask for each of the N dimensions of the
    output... this allows us to compute J with pytorch's
    vector-Jacobian engine
    """

    # expand the input, one copy per output dimension
    n_outputs = int(n_outputs)
    repeat_arg = (n_outputs,) + (1,) * len(x.size())
    xr = x.repeat(*repeat_arg)

    # both y and I are shape (n_outputs, n_outputs)
    #  checking y shape lets us report something meaningful
    y = fxn(xr).view(n_outputs, -1)

    if y.size(1) != n_outputs:
        raise ValueError('Function `fxn` does not give output '
                         'compatible with `n_outputs`=%d, size '
                         'of fxn(x) : %s'
                         '' % (n_outputs, y.size(1)))
    I = torch.eye(n_outputs, dtype=y.dtype, device=xr.device)

    # a single backward pass over all n_outputs copies
    J = torch.autograd.grad(y, xr,
                            grad_outputs=I,
                            retain_graph=retain_graph,
                            create_graph=True,  # for higher order derivatives
                            )

    return J[0]

def _rev_jacobian_simple(fxn, x, n_outputs, retain_graph=True):

    n_outputs = int(n_outputs)

    # leaf copy of x that autograd can differentiate w.r.t.
    xd = x.detach().requires_grad_(True)
    n_inputs = int(xd.size(0))

    y = fxn(xd.view(1, n_inputs)).view(-1)

    if y.size(0) != n_outputs:
        raise ValueError('Function `fxn` does not give output '
                         'compatible with `n_outputs`=%d, size '
                         'of fxn(x) : %s'
                         '' % (n_outputs, y.size(0)))
    I = torch.eye(n_outputs, dtype=y.dtype, device=xd.device)

    # one backward pass per output dimension: row i of J = e_i^T J
    J = torch.zeros([n_outputs, n_inputs], device=xd.device)
    for i in range(n_outputs):
        J[i] = torch.autograd.grad(y, xd,
                                   grad_outputs=I[i],
                                   retain_graph=retain_graph,
                                   create_graph=True,  # for higher order derivatives
                                   )[0]
    return J
``````
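
A likely, though unverified here, reason for the 10-100x gap: `_rev_jacobian_simple` launches `n_outputs` separate backward passes, each made of small operations with their own Python and kernel-launch overhead, while `_rev_jacobian` does a single forward and backward pass on a batch of `n_outputs` copies. A rough CPU timing sketch of both strategies on a hypothetical small MLP (on CUDA you would also need `torch.cuda.synchronize()` before reading the clock); timings will vary by machine.

```python
import time
import torch

torch.manual_seed(0)
n_in, n_out = 20, 200
net = torch.nn.Sequential(torch.nn.Linear(n_in, 64), torch.nn.Tanh(),
                          torch.nn.Linear(64, n_out))
x = torch.randn(n_in)

def jac_loop(x):
    # n_out separate backward passes, one VJP each.
    x = x.detach().requires_grad_(True)
    y = net(x)
    rows = [torch.autograd.grad(y[i], x, retain_graph=True)[0] for i in range(n_out)]
    return torch.stack(rows)

def jac_repeat(x):
    # One forward/backward over n_out copies of x: a single batched VJP.
    xr = x.detach().repeat(n_out, 1).requires_grad_(True)
    y = net(xr)                                   # shape (n_out, n_out)
    return torch.autograd.grad(y, xr, grad_outputs=torch.eye(n_out))[0]

for fn in (jac_loop, jac_repeat):
    t0 = time.perf_counter()
    fn(x)
    print(f"{fn.__name__}: {time.perf_counter() - t0:.4f}s")
```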

What if you use the same repeat trick for forward mode? IE, replicate x and v vector and do the same as in reverse mode?

That would be faster but consume much more memory. That might be OK in your case though!

You can do the same for the “forward mode trick” as well I guess. But not sure this will make the Jacobian computation faster than this one.

Tricks on tricks for the win?

@Yaroslav_Bulatov and @albanD thanks. In the end I copied Yaroslav – doing the “repeat” trick twice. I was hoping I would not have to re-compute N reverse passes, but it seems unavoidable.
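
A minimal sketch of what "doing the repeat trick twice" might look like (the code actually used isn't shown in the thread, and `fwd_jacobian_repeat` is a hypothetical name): repeat `x` once per input dimension, give every copy its own row of the dummy `v`, and pass the identity in the second pass so that copy i yields column J e_i. It assumes `fxn` processes rows of a batch independently, and memory grows with `n_inputs` copies, which is the tradeoff mentioned above.

```python
import torch

def fwd_jacobian_repeat(fxn, x):
    """Double-VJP trick with the repeat trick applied twice (no Python loop).

    Assumes fxn maps a batch (B, n_inputs) -> (B, n_outputs) row by row.
    """
    n_inputs = x.numel()
    # One copy of x per input dimension (one copy per Jacobian column).
    xr = x.detach().reshape(1, -1).repeat(n_inputs, 1).requires_grad_(True)
    y = fxn(xr)                                   # (n_inputs, n_outputs)
    # One dummy cotangent row per copy, so the copies do not share a v.
    v = torch.zeros_like(y, requires_grad=True)
    vjp = torch.autograd.grad(y, xr, grad_outputs=v, create_graph=True)[0]  # (n_inputs, n_inputs)
    # Row i of the identity picks e_i for copy i, giving J e_i in v.grad[i].
    I = torch.eye(n_inputs, dtype=x.dtype, device=x.device)
    cols = torch.autograd.grad(vjp, v, grad_outputs=I)[0]                  # (n_inputs, n_outputs)
    return cols.T                                                          # (n_outputs, n_inputs)

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.Tanh(),
                          torch.nn.Linear(16, 7))
x0 = torch.randn(4)
J = fwd_jacobian_repeat(net, x0)  # shape (7, 4)
```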

With code very similar to Yaroslav’s I get a ~5x speed boost using the repeat trick.

And consistent with @albanD’s expectations, the faster, non-looping implementations do use much more memory. So there is a tradeoff. For me speed is more important and I have memory to spare, so I will probably roll with that :).

Thanks to both of you guys!


Instead of using `torch.autograd`, I’ve tried computing the Jacobian “manually”.

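``````python
import torch

def register_hooks(model):
    # hypothetical stand-in for the helper in the linked example code:
    # store the input of every Linear layer as `module.data_input`
    def save_input(module, inputs, output):
        module.data_input = inputs[0].detach()
    for m in model.modules():
        # register only once, so repeated calls do not stack hooks
        if isinstance(m, torch.nn.Linear) and not hasattr(m, '_save_input_hook'):
            m._save_input_hook = m.register_forward_hook(save_input)

def manual_jacobian(x, model):
    # x -> fc1 -> s1
    # s1 -> relu -> a1
    # a1 -> fc2 -> s2 (output)

    register_hooks(model)
    output = model(x)

    n = output.size(0)  # batch size
    fc1 = model.fc1
    fc2 = model.fc2

    # Jacobian s2 (output) -> a1
    Ja1 = torch.stack([fc2.weight] * n)

    # ReLU derivative a1 -> s1 (a1 > 0 exactly where s1 > 0)
    d = (fc2.data_input > 0).to(fc2.data_input.dtype)

    # Jacobian s2 (output) -> s1
    # (torch.einsum used here in place of the original `contract`)
    Js1 = torch.einsum('bij,bj->bij', Ja1, d)

    # Jacobian s2 (output) -> x
    Jx = torch.einsum('bij,jk->bik', Js1, fc1.weight)

    return Jx
``````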

Here is the example code.

By limiting the types of layers (Linear, ReLU, and Conv2d for the example code above), we know how to compute the Jacobian efficiently without using `torch.autograd` or the repeat trick.
For example, we can get the Jacobian of ConvNd by using ConvTransposeNd.
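
To illustrate the ConvNd point for Conv2d, assuming stride 1, "same" padding, and no bias (all sizes are illustrative): `conv_transpose2d` with the same weight computes the VJP of `conv2d` with respect to its input, so feeding it every basis vector of the output space as a batch returns all rows of the Jacobian in one call.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C_in, C_out, H, W, k, p = 2, 3, 5, 5, 3, 1
weight = torch.randn(C_out, C_in, k, k)
x = torch.randn(1, C_in, H, W)

out = F.conv2d(x, weight, padding=p)             # (1, C_out, H, W) for stride 1
M = out.numel()
# Each row of the identity, reshaped like the output, is one basis cotangent.
# conv_transpose2d maps a batch of them back to input space at once.
basis = torch.eye(M).reshape(M, C_out, H, W)
J = F.conv_transpose2d(basis, weight, padding=p).reshape(M, -1)  # (M, C_in*H*W)
```

Because the convolution is linear in its input, this Jacobian does not depend on `x`.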

Benchmark on 1 GPU: my approach (`manual mode`) vs `reverse mode` for a Linear NN

``````
bs: 100
n_inputs: 10
n_outputs: 100
loop: 100
-------------
reverse mode: 0.459s
manual mode: 0.195s
(Jx_rev - Jx_man).max(): 0.0
``````

And it’s much faster when there is a large number of outputs:

``````
bs: 100
n_inputs: 10
n_outputs: 1000
loop: 100
-------------
reverse mode: 18.032s
manual mode: 1.884s
(Jx_rev - Jx_man).max(): 0.0
``````

For a CNN

``````
bs: 1
input shape: torch.Size([3, 32, 32])
n_outputs: 100
loop: 1
-------------
reverse mode: 0.416s
manual mode: 0.232s
(Jx_rev - Jx_man).max(): 3.725290298461914e-09
``````

To summarize the pros & cons of `manual mode`:

pros

• it can compute the Jacobian in a single backpropagation for an arbitrary number of outputs.
• compared to the repeat trick, it doesn’t require any copies of the inputs for the forward pass.
• it can evaluate the Jacobian layer by layer, so it should work even with a large model, e.g., ResNet-50 for ImageNet classification.

cons

• we need to figure out the way to calculate the Jacobian for each module we need for constructing a network.
• we need to detect the type and the order of the modules in the network.
• it would be more complicated for a network with skip connections.
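
For the simplest case, a plain `nn.Sequential` with no skip connections, the type-and-order problem reduces to walking the modules in execution order and applying one local Jacobian rule per layer type. A hypothetical helper (not the linked example code), assuming only `nn.Linear` and `nn.ReLU` layers and a batched input of shape `(bs, n_inputs)`:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def manual_jacobian_sequential(model, x):
    """Per-example Jacobians d(output)/d(input), shape (bs, n_out, n_in),
    for an nn.Sequential made only of nn.Linear and nn.ReLU layers."""
    J = None   # Jacobian of the current activation w.r.t. x: (bs, dim, n_in)
    h = x      # current activation: (bs, dim)
    for layer in model:   # modules are visited in execution order
        if isinstance(layer, nn.Linear):
            W = layer.weight                      # local Jacobian: (out, in)
            J = W.expand(h.size(0), -1, -1) if J is None else torch.matmul(W, J)
        elif isinstance(layer, nn.ReLU):
            mask = (h > 0).to(h.dtype)            # ReLU derivative: (bs, dim)
            J = torch.diag_embed(mask) if J is None else mask.unsqueeze(-1) * J
        else:
            raise NotImplementedError(f"no Jacobian rule for {type(layer).__name__}")
        h = layer(h)
    return J

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(),
                      nn.Linear(32, 32), nn.ReLU(),
                      nn.Linear(32, 5))
x = torch.randn(4, 10)
Jx = manual_jacobian_sequential(model, x)  # shape (4, 5, 10)
```

Supporting more layer types means adding one rule per type; skip connections would need a graph walk rather than a simple loop.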

I am trying to overcome these shortcomings of `manual mode`, and I hope this approach helps.

Hi,

You can check this package that uses this trick in a slightly more general way: https://github.com/cybertronai/autograd-lib

Thanks!
But the Jacobian computation in autograd-lib includes the loop along the output dimension.

``````
bs: 128
n_inputs: 1000
n_outputs: 1000
loop: 1
-------------
reverse mode: 1.193s
manual mode: 0.238s
``````

@Kazuki_Osawa thanks for sharing, that is cool. Right now I am only using 2dConv / 2dTransposedConv / 2dBatchNorm / Linear layers, so perhaps it would not be too much work to assemble the components for manual Jacobian calculations. I have been changing my architecture frequently, so the ability to have things “just work” is quite nice :).

I will keep this in mind though if I need more speed.

Curious if you benchmarked JAX at all against your manual mode? At some point just implementing stuff in numpy and using JAX is going to be simpler… or going “full manual”

@tjlane
I haven’t tried JAX, but that comparison sounds very interesting! I will also try JAX and post the results at some point.

Thanks!

I tried computing the Jacobian with JAX (`jax.jacfwd` and `jax.jacrev`) on 1 GPU (NVIDIA Tesla V100).

I defined the same Linear NN for JAX using Flax:

``````python
class FlaxNet(flax.nn.Module):

    def apply(self, x):
        s1 = flax.nn.Dense(x, features=100, bias=False)
        a1 = flax.nn.relu(s1)
        s2 = flax.nn.Dense(a1, features=1000, bias=False)

        return s2
``````

Jacobian with JAX:

``````python
def jax_jacobian(x, model, mode='fwd'):
    output = model(x)
    if mode == 'fwd':
        Jx = jax.jacfwd(model)(x)
    else:
        Jx = jax.jacrev(model)(x)

    return Jx
``````

For simplicity, I set the mini-batch size `bs` = 1.

``````
bs: 1
n_inputs: 1000
n_outputs: 1000
loop: 100
-------------
jax.jacfwd: 1.959s
jax.jacrev: 1.240s
reverse mode: 2.564s
manual mode: 0.471s
``````

As I’m a beginner with JAX, I might not be fully utilizing its potential, but the `manual mode` in PyTorch is still the fastest.
(example code for these runs)


Your JAX runs look slow: I’ve tried your code on an oldish CPU and it runs significantly faster. A large overhead is being introduced for some reason. I would first try calling the JAX jit method on the jacobian function, saving the resulting function to a variable, and then calling that.

With JIT and jacrev on my CPU (E5-2680 v3 @ 2.50GHz) I get a time of 0.393. NB: you should call block_until_ready() on the last output so that asynchronous execution does not produce misleading timings.

Thanks for sharing!
Yes, JAX could be much faster with JIT on a CPU. But even with the JIT compiler, when the dimension is so large that we can ignore the CPU overhead, the manual mode (`torch.man`) is faster than JAX.

I have updated my code to measure the time with `jax.jit` and `jax.vmap`.

``````python
jac_fun = jax.jit(jax.jacrev(model))
Jx_fn = jax.jit(jax.vmap(jac_fun, in_axes=(0,)))
``````

On a GPU (NVIDIA Tesla V100),

``````
mode: torch.auto
bs: 32
n_inputs: 1000
hidden_ndim: 1000
n_outputs: 1000
n_layers: 3
-------------
loop: 100
device: cuda
torch auto rev: 2.591s
``````
``````
mode: torch.man
bs: 32
n_inputs: 1000
hidden_ndim: 1000
n_outputs: 1000
n_layers: 3
-------------
loop: 100
device: cuda
torch manual rev: 0.454s
``````
``````
mode: jax
bs: 32
n_inputs: 1000
hidden_ndim: 1000
n_outputs: 1000
n_layers: 3
-------------
loop: 100
device: [GpuDevice(id=0)]
jit(jax.jacrev): 1.024s
``````

This is not about PyTorch vs JAX, but about the FLOPs we need to evaluate the Jacobian.
It is the nature of autograd to evaluate the vector-Jacobian product (vjp) or the Jacobian-vector product (jvp), so it needs extra computation compared to the manual mode.

For the reverse mode, autograd requires an extra (M_L)(M_1)(M_0) FLOPs, where
M_L: output dimension of the NN
M_0: input dimension of the NN (the first layer)
M_1: output dimension of the first layer
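
To get a feel for the size of this term, plug in the benchmark configuration above (M_0 = n_inputs = 1000, M_1 = hidden_ndim = 1000, M_L = n_outputs = 1000):

$$M_L M_1 M_0 = 10^3 \times 10^3 \times 10^3 = 10^9$$

That is about a billion extra FLOPs, presumably per example (so roughly 3.2 × 10^10 for `bs` = 32, if the cost scales linearly with the batch size).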