Computing batch Jacobian efficiently

I’m trying to compute the Jacobian (and its inverse) of the output of an intermediate layer (block1) with respect to the input of the first layer. The code looks like:

import sys
import torch

def getInverseJacobian(net2, x):
    # x has shape (n_batches, input_dim) and must have requires_grad=True
    # This version assumes a single sample in the batch (n_batches == 1)

    # Jacobian matrix, one row per output element
    jac = torch.zeros(size=(x.shape[1], x.shape[1]))

    # Take the input point and forward it through the 1st block
    y = net2.block1(x)

    # One backward pass per output element fills one row of the Jacobian
    for i in range(x.shape[1]):
        jac[i, :] = torch.autograd.grad(y[0][i], x, create_graph=True)[0]

    # Invert the Jacobian using the Moore-Penrose pseudo-inverse
    jac_inverse = torch.pinverse(jac)

    if torch.isnan(jac_inverse).any():
        print('Nan encountered in Jacobian !')
        sys.exit(0)
    
    return jac_inverse

This works well for a single sample in a batch. How do I convert it to compute the Jacobian for a complete batch without using a loop? This function will be called many times during training, and a loop would not be ideal.
Any suggestions? Am I calculating the Jacobian efficiently in the first place? [It is accurate, though.]

Hi,

Yes, this looks like the right way to do it.
FYI, we now have a built-in function that does the same thing: https://pytorch.org/docs/stable/autograd.html#torch.autograd.functional.jacobian

There is no better way to compute the Jacobian yet, I’m afraid. But we’re working on it.
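
For example, something along these lines for the single-sample case in your function (just a sketch; the wrapper name is illustrative, and it assumes block1 maps a (1, dim) input to a (1, dim) output, as the square jac implies):

import torch

def get_inverse_jacobian_builtin(net2, x):
    # x: (1, dim). jacobian() returns output_shape + input_shape,
    # i.e. (1, dim, 1, dim) here, which reshapes to a square (dim, dim)
    jac = torch.autograd.functional.jacobian(net2.block1, x)
    jac = jac.reshape(x.shape[1], x.shape[1])
    return torch.pinverse(jac)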

Hi, do you mean the built-in function also works on one input point?

By the way, may I ask for an example of https://pytorch.org/docs/stable/autograd.html#torch.autograd.functional.jacobian taken with respect to the parameters of a network, please? I have no clue how to set it up, since the first argument of jacobian is a function.

Not sure what you mean by “one input point”; could you clarify?

For nn.Module, you can check this answer: Get gradient and Jacobian wrt the parameters - #3 by albanD
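
As a rough illustration of the “first argument is a function” point (this is not the linked answer, just a sketch with a stand-in linear layer), you can wrap the forward pass in a function that takes the parameter you want to differentiate with respect to:

import torch

layer = torch.nn.Linear(3, 2)
x = torch.randn(1, 3)

def output_wrt_weight(weight):
    # Recompute the output from the given weight so autograd can trace
    # from this parameter to the output
    return torch.nn.functional.linear(x, weight, layer.bias)

jac = torch.autograd.functional.jacobian(output_wrt_weight, layer.weight)
print(jac.shape)   # (1, 2, 2, 3): output shape followed by weight shape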

Hi,

Recently, I ran into the same problem and tried to do the batch_jacobian operation with for-loops. Although it works, the runtime is too long. Finally, I implemented batch_jacobian in another way; it is more efficient, and the runtime is close to that of tf.GradientTape.

from torch import autograd

def batch_jacobian(func, x, create_graph=False):
  # x in shape (Batch, Length)
  def _func_sum(x):
    # Summing over the batch is safe: sample b's output depends only on
    # sample b's input, so d(sum)/dx[b] is exactly that sample's Jacobian
    return func(x).sum(dim=0)
  # jacobian of the summed output is (out_dim, Batch, Length); permute to (Batch, out_dim, Length)
  return autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)
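
If it helps, a quick shape check with a placeholder network (the layer and sizes are just for illustration, using the batch_jacobian above):

import torch
from torch import nn

net = nn.Linear(4, 3)
x = torch.randn(8, 4)            # (Batch, Length)
jac = batch_jacobian(net, x)     # -> (Batch, out_dim, in_dim)
print(jac.shape)                 # torch.Size([8, 3, 4])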

Note that if you’re using the latest version of PyTorch, there is a vectorize=True flag for functional.jacobian() that might speed things up in some cases :slight_smile:
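
For example, the batch_jacobian helper above could pass the flag through (a sketch, assuming a PyTorch version that has it):

import torch

def batch_jacobian_vectorized(func, x, create_graph=False):
    # Same summing trick as above, but jacobian() batches the backward
    # passes internally instead of looping over output elements
    def _func_sum(x):
        return func(x).sum(dim=0)
    jac = torch.autograd.functional.jacobian(
        _func_sum, x, create_graph=create_graph, vectorize=True)
    return jac.permute(1, 0, 2)   # (Batch, out_dim, in_dim)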

I’m unsure what the most efficient implementation is if both my inputs and the outputs are batched. Specifically, if I have inputs of shape [B, n] and func maps to outputs of shape [B, m], then calling

jac = torch.autograd.functional.jacobian(func, inputs, vectorize=True)

returns a tensor of shape [B, m, B, n] (the output shape followed by the input shape).

But if there is no interaction between batch entries, then the blocks jac[i, :, j, :] with i ≠ j are just zero tensors, and I only really need to compute jac[i] = jac[i, :, i, :].

Of course, I can just select the appropriate entries of jac, but I’m wondering if this is still the most efficient approach here since a lot of unnecessary gradients are computed. Is there a better way?

Seems to be the same problem as in this thread.
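
For what it’s worth, the “select the appropriate entries” part can be written with advanced indexing; a toy setup to make the shapes concrete (the linear layer and sizes are placeholders):

import torch

B, n, m = 8, 4, 3
func = torch.nn.Linear(n, m)
inputs = torch.randn(B, n)

jac = torch.autograd.functional.jacobian(func, inputs, vectorize=True)  # (B, m, B, n)
idx = torch.arange(B)
per_sample_jac = jac[idx, :, idx, :]   # (B, m, n): keeps only the nonzero diagonal blocks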

Have a look at functorch; that’ll allow you to vectorize over your batch (so you won’t have those zero tensors).

Thank you! That was exactly what I was looking for!

For anyone else wondering, functorch.vmap(functorch.jacref(func))(inputs) does the trick.

jacrev not jacref! :wink:

functorch.vmap(functorch.jacrev(func))(inputs) #vectorized reverse-mode AD

For anyone else wondering, forward-mode AD is also supported via jacfwd if you need it!
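
For completeness, the forward-mode call is a drop-in swap (same func and inputs as above):

import functorch

jac_fwd = functorch.vmap(functorch.jacfwd(func))(inputs)
# Forward-mode tends to be preferable when each sample's input dimension
# is smaller than its output dimension; jacrev is the other way around.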

Thank you so much, I have been trying to solve this problem for a week!

Just an update for anyone who reads this thread in the future: as of PyTorch 2, the functorch library is now included in PyTorch, so you can replace functorch with torch.func. For the most part the syntax is the same, except that if you have an nn.Module you’ll need to create a ‘functional’ version of your model.

For example,

import torch
from torch.func import vmap, jacrev, functional_call

model = myModel(*args, **kwargs) #our network

params = dict(model.named_parameters())
inputs = torch.randn(batch_size, input_size) #random input data

def fmodel(params, inputs): #functional version of model
  return functional_call(model, params, inputs)

# Per-sample Jacobian of the output w.r.t. the inputs (argnums=1);
# in_dims=(None, 0) broadcasts params and maps over the batch dimension
result = vmap(jacrev(fmodel, argnums=1), in_dims=(None, 0))(params, inputs)
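
If you need the per-sample Jacobian with respect to the parameters instead of the inputs, the same wrapper should work with argnums=0 (a sketch, reusing the params and inputs above):

param_jacs = vmap(jacrev(fmodel, argnums=0), in_dims=(None, 0))(params, inputs)
# param_jacs is a dict keyed by parameter name; each entry has shape
# (batch_size, *per_sample_output_shape, *parameter_shape)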

The documentation for torch.func can be found here.

Also, if you’re migrating from functorch to torch.func, they have a documentation page on the changes between them here.
