How to calculate a jacobian for an entire batch

I have a network that takes a vector of size 10 and returns a vector of size 20.
I want to calculate the Jacobian of the network's output with respect to its input.
I used the code

```python
torch.autograd.functional.jacobian(func=network, inputs=x)
```
to calculate it, and it worked: I get the correct 20 × 10 matrix.
However, when I try to do it over an entire batch (let's say of size 40), I get far more than I wanted: a tensor of shape 40 × 20 × 40 × 10.

I believe the jacobian function computes the gradient of every output in the batch w.r.t. every input in the batch (ignoring the separation between samples).
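A minimal sketch of what is happening, assuming a hypothetical stand-in network (a small `nn.Sequential` mapping 10 inputs to 20 outputs, not the poster's actual model): the batched call differentiates every output element with respect to every input element, so the batch dimension appears twice, and for a network that processes samples independently every cross-sample block is zero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the network in the question: 10 inputs -> 20 outputs.
network = nn.Sequential(nn.Linear(10, 32), nn.Tanh(), nn.Linear(32, 20))

x = torch.randn(40, 10)  # a batch of 40 samples

# Every output element is differentiated w.r.t. every input element,
# so the batch dimension shows up twice in the result.
full = torch.autograd.functional.jacobian(network, x)
print(full.shape)  # torch.Size([40, 20, 40, 10])

# Entry [i, :, j, :] is d output_i / d input_j. Zero out the diagonal blocks
# (i == j) and check that what remains (the i != j blocks) is all zeros.
off_diag = full.clone()
idx = torch.arange(x.shape[0])
off_diag[idx, :, idx, :] = 0
print(off_diag.abs().max())  # all cross-sample blocks are zero
```

So most of the work in the batched call goes into derivatives that are known in advance to be zero.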

How can I get the function to treat the samples in the batch separately without using a for loop and splitting them up myself?

I think I can extract the per-sample Jacobians like so:

```python
indices = torch.arange(x.shape[0])
jacobian = jacobian[indices, :, indices, :]  # keep blocks [i, :, i, :] -> shape (40, 20, 10)
```
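A quick way to check that indexing, sketched with the same kind of hypothetical stand-in network (not the poster's model): the diagonal blocks it selects match a per-sample loop. Note that because the two index tensors are separated by a slice, advanced indexing moves the indexed dimension to the front, giving shape (40, 20, 10). This only extracts the result; the full 40 × 20 × 40 × 10 tensor is still computed first.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in network (10 -> 20), as in the question.
network = nn.Sequential(nn.Linear(10, 32), nn.Tanh(), nn.Linear(32, 20))
x = torch.randn(40, 10)

full = torch.autograd.functional.jacobian(network, x)  # (40, 20, 40, 10)

# Select the diagonal blocks [i, :, i, :]. The index tensors are not adjacent,
# so the indexed dimension goes first in the result.
indices = torch.arange(x.shape[0])
batch_jac = full[indices, :, indices, :]
print(batch_jac.shape)  # torch.Size([40, 20, 10])

# Cross-check against one jacobian call per sample (each is 20 x 10).
looped = torch.stack(
    [torch.autograd.functional.jacobian(network, x[i]) for i in range(x.shape[0])]
)
print(torch.allclose(batch_jac, looped, atol=1e-6))  # True
```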

Is this correct? Is there a more efficient way to prevent PyTorch from calculating unnecessary derivatives?
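One common workaround with the plain autograd API (a different technique from the torch.func approach discussed below, shown here as a sketch) is to sum the outputs over the batch before differentiating. It assumes that each sample's output depends only on that sample's input, i.e. no layers that mix samples such as BatchNorm in training mode. Then d(sum_k out_k)/d x_i = d out_i / d x_i, and only 20 backward passes are needed instead of 40 × 20. The network is again a hypothetical stand-in.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in network (10 -> 20); assumes no cross-sample layers.
network = nn.Sequential(nn.Linear(10, 32), nn.Tanh(), nn.Linear(32, 20))
x = torch.randn(40, 10)

def summed(inp):
    # Summing over the batch is harmless because cross-sample derivatives are zero.
    return network(inp).sum(dim=0)  # shape (20,)

jac = torch.autograd.functional.jacobian(summed, x)  # shape (20, 40, 10)
batch_jac = jac.permute(1, 0, 2)                     # shape (40, 20, 10)
print(batch_jac.shape)  # torch.Size([40, 20, 10])
```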

Hi Yedidya!

If I understand what you want to do, I believe that @AlphaBetaGamma96’s
post that uses vmap() and jacrev() will give you your “batch jacobian:”
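The linked post is not reproduced here, so the following is only a minimal sketch of the vmap()/jacrev() approach, assuming a hypothetical stand-in network and the `torch.func` API (`functional_call`, `jacrev`, `vmap`). The per-sample function is written for one input vector; `jacrev` differentiates it w.r.t. that vector, and `vmap` maps it over the batch while sharing the parameters.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)

# Hypothetical stand-in network (10 -> 20).
network = nn.Sequential(nn.Linear(10, 32), nn.Tanh(), nn.Linear(32, 20))
x = torch.randn(40, 10)

# Parameters as an explicit input to a pure function.
params = {name: p.detach() for name, p in network.named_parameters()}

def single_sample(params, sample):
    # sample has shape (10,); output has shape (20,).
    return functional_call(network, params, (sample,))

# argnums=1: differentiate w.r.t. the sample, not the parameters.
# in_dims=(None, 0): share params, map over dim 0 of x.
batch_jac = vmap(jacrev(single_sample, argnums=1), in_dims=(None, 0))(params, x)
print(batch_jac.shape)  # torch.Size([40, 20, 10])
```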

Best.

K. Frank

2 Likes

This solution indeed seems far more efficient than mine. However, I must admit I didn't understand how it works.
What is the difference between the jacobian function and the jacrev function? Also, how does the vmap function work?
Thanks anyway, your suggestion cut my run time by 70%.

Hi @Yedidya_kfir, and @KFrank!

The torch.func namespace works by creating a so-called functionalized version of your model: a pure function whose output depends only on its explicit inputs (i.e. no other non-local variables). This is why your network parameters are now an input, rather than internal state of the module as before.
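A small sketch of what "parameters become an input" means, assuming a hypothetical stand-in network: `torch.func.functional_call` runs the module with a parameter dict you pass in, so the call is a function of `(params, x)` alone.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical stand-in network (10 -> 20).
network = nn.Sequential(nn.Linear(10, 32), nn.Tanh(), nn.Linear(32, 20))
x = torch.randn(40, 10)

# Pull the parameters out of the module into an explicit dict ...
params = {name: p.detach() for name, p in network.named_parameters()}

# ... and pass them back in as an input.
out_functional = functional_call(network, params, (x,))
print(torch.allclose(out_functional, network(x)))  # True

# Different parameters give a different output without modifying the module.
zeroed = {name: torch.zeros_like(p) for name, p in params.items()}
print(functional_call(network, zeroed, (x,)).abs().max().item())  # 0.0
```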

The jacrev function is torch.func's counterpart to torch.autograd.functional.jacobian and computes Jacobians using reverse-mode automatic differentiation. It lets you compute the Jacobian of a functionalized module within the torch.func namespace.
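For a single sample, the two functions should agree. A sketch, assuming a hypothetical stand-in network and passing the module directly to `jacrev` (its parameters are then treated as constants, and the Jacobian is taken w.r.t. the input):

```python
import torch
import torch.nn as nn
from torch.func import jacrev

torch.manual_seed(0)

# Hypothetical stand-in network (10 -> 20).
network = nn.Sequential(nn.Linear(10, 32), nn.Tanh(), nn.Linear(32, 20))
sample = torch.randn(10)  # one input vector, no batch dimension

j_autograd = torch.autograd.functional.jacobian(network, sample)  # (20, 10)
j_func = jacrev(network)(sample)                                 # (20, 10)
print(j_func.shape)                                   # torch.Size([20, 10])
print(torch.allclose(j_autograd, j_func, atol=1e-6))  # True
```

The practical difference is that `jacrev` returns a function that composes with other `torch.func` transforms such as `vmap`.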

Now, why did I suggest to use the torch.func namespace and not just torch.autograd.functional.jacobian?

Well, torch.func.vmap lets us compute all samples in parallel. You write your function for a single sample, and vmap transforms it into a function that processes a whole batch in one vectorized call. This is where the speed-up comes from: instead of calling the autograd engine N times (for N samples) in a for loop, you effectively call it once and compute all samples in parallel.
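The idea of vmap is easiest to see on a function that only accepts a single sample. `torch.dot` requires 1-D tensors; `vmap(torch.dot)` accepts a leading batch dimension and applies it per row, without a Python loop. This is a minimal illustration, not tied to the network above.

```python
import torch
from torch.func import vmap

torch.manual_seed(0)
a = torch.randn(40, 10)
b = torch.randn(40, 10)

# torch.dot only works on 1-D tensors (a single sample) ...
# ... vmap turns it into a function over a leading batch dimension.
batched_dot = vmap(torch.dot)
result = batched_dot(a, b)  # shape (40,)

# Equivalent Python loop, for comparison.
looped = torch.stack([torch.dot(a[i], b[i]) for i in range(a.shape[0])])
print(result.shape)                               # torch.Size([40])
print(torch.allclose(result, looped, atol=1e-5))  # True
```

`vmap(jacrev(f))` is the same pattern with `jacrev(f)` as the per-sample function.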

There are limitations to the torch.func namespace, so I’ve left a link to the docs here.