Computing batch Jacobian efficiently

Just an update for anyone who reads this thread in the future: as of PyTorch 2.0, the functorch library is included in PyTorch as torch.func. For the most part you can replace functorch with torch.func and keep the same syntax. The main exception is that if you have an nn.Module, you’ll need to create a ‘functional’ version of your model with torch.func.functional_call (the replacement for functorch’s make_functional).
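A minimal sketch of what the ‘functional’ version means, assuming a toy nn.Linear in place of your own model (the sizes are arbitrary). functional_call runs the module’s forward with whatever parameter dict you pass in, so with the module’s own parameters it reproduces model(x), and with a different dict it uses those tensors instead without modifying the module.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)

# Toy module standing in for "your model" (hypothetical sizes).
model = nn.Linear(3, 2)
x = torch.randn(5, 3)

# The module's parameters as a plain {name: tensor} dict.
params = dict(model.named_parameters())

# With the module's own parameters, functional_call matches model(x).
out_functional = functional_call(model, params, (x,))
out_module = model(x)

# Passing a different dict swaps the weights for this call only;
# the module's own parameters are left untouched.
zero_params = {name: torch.zeros_like(p) for name, p in params.items()}
out_zero = functional_call(model, zero_params, (x,))
```

This substitution is what lets torch.func transforms such as jacrev treat the parameters as explicit function arguments.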

For example,

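```python
import torch
from torch.func import vmap, jacrev, functional_call

model = myModel(*args, **kwargs)  # our network (any nn.Module)

params = dict(model.named_parameters())  # {name: tensor} for functional_call
inputs = torch.randn(batch_size, input_size)  # random input data

def fmodel(params, inputs):  # functional version of model
    return functional_call(model, params, (inputs,))

# Jacobian w.r.t. inputs (argnums=1), vmapped over the batch dim of inputs
# while params are shared across the batch (in_dims=(None, 0)).
# For a per-sample output of shape (output_size,), result has shape
# (batch_size, output_size, input_size).
result = vmap(jacrev(fmodel, argnums=1), in_dims=(None, 0))(params, inputs)
```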
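A self-contained sketch of the same pattern, assuming a small hypothetical MLP in place of myModel and arbitrary sizes. It checks the vmapped result against a plain per-sample loop using torch.autograd.functional.jacobian. The parameters are detached here only because we differentiate with respect to the inputs; that is a choice for this sketch, not a requirement.

```python
import torch
import torch.nn as nn
from torch.func import vmap, jacrev, functional_call

torch.manual_seed(0)

batch_size, input_size, hidden_size, output_size = 4, 3, 8, 2

# Hypothetical stand-in for myModel: a small MLP.
model = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.Tanh(),
    nn.Linear(hidden_size, output_size),
)

# Detached copies: we only differentiate w.r.t. the inputs here.
params = {k: v.detach() for k, v in model.named_parameters()}
inputs = torch.randn(batch_size, input_size)

def fmodel(params, x):
    # Inside vmap, x is a single sample of shape (input_size,).
    return functional_call(model, params, (x,))

# jacrev(..., argnums=1): Jacobian w.r.t. x for one sample.
# vmap(..., in_dims=(None, 0)): share params, map over the batch dim of inputs.
batch_jac = vmap(jacrev(fmodel, argnums=1), in_dims=(None, 0))(params, inputs)
print(batch_jac.shape)  # torch.Size([4, 2, 3])

# Reference: per-sample loop with torch.autograd.functional.jacobian.
loop_jac = torch.stack([
    torch.autograd.functional.jacobian(model, inputs[i])
    for i in range(batch_size)
])
print(torch.allclose(batch_jac, loop_jac, atol=1e-6))
```

jacrev uses reverse mode, which is usually the better fit when the output is smaller than the input; torch.func.jacfwd takes the same argnums argument and can be swapped in when the input dimension is the smaller one.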

The documentation for torch.func can be found here.

Also, if you’re migrating from functorch to torch.func, the PyTorch docs have a page on the changes between them here.
