Computing the Jacobian column-wise with PyTorch, without a loop

Hi @ergias,

In the other Discuss link you shared, my answer using the functorch library was referenced. However, with the release of PyTorch 2.0, functorch now lives within torch itself, under the torch.func namespace.

For the most part things have stayed the same, except that instead of functionalizing your net directly you call torch.func.functional_call(net, params, args), where args is your input tensor (or a tuple of inputs) and params is a dictionary of your network's parameters.

For example,

from torch.func import functional_call, jacrev, vmap

net = Net(*args, **kwargs)
params = dict(net.named_parameters())  # parameters as a dict of {name: tensor}

def fcall(params, x):
  return functional_call(net, params, x)

# jacrev differentiates w.r.t. argnums=0 (the params dict),
# and vmap maps that over the batch dimension of x
per_sample_jacobian_wrt_params = vmap(jacrev(fcall, argnums=0), in_dims=(None, 0))(params, x)
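To make the shapes concrete, here's a self-contained sketch with a toy linear layer standing in for your Net (the layer sizes and data below are just hypothetical stand-ins):

import torch
from torch.func import functional_call, jacrev, vmap

net = torch.nn.Linear(3, 2)            # stand-in for your Net
params = dict(net.named_parameters())
x = torch.randn(8, 3)                  # batch of 8 samples

def fcall(params, x):
  return functional_call(net, params, x)

jac = vmap(jacrev(fcall, argnums=0), in_dims=(None, 0))(params, x)
# the result is a dict: one per-sample Jacobian per parameter,
# each with shape (batch, *output_shape, *param_shape)
print(jac["weight"].shape)  # torch.Size([8, 2, 2, 3])
print(jac["bias"].shape)    # torch.Size([8, 2, 2])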

If you want to remove your loop you’ll need to do a similar trick to what I showed above. If your function has more outputs than inputs, you may want to use forward-mode AD via torch.func.jacfwd to get better performance.
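As a sketch (reusing the fcall, params, and x from above), that's a one-word change, swapping jacrev for jacfwd:

from torch.func import jacfwd, vmap

# forward-mode AD builds the Jacobian column-by-column,
# which wins when the function has more outputs than inputs
per_sample_jacobian_wrt_params = vmap(jacfwd(fcall, argnums=0), in_dims=(None, 0))(params, x)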

Also, don’t use the torch.jit.script decorator here; it really won’t do much for this kind of code. It’s best to use torch.compile() instead, but even that will likely give only minimal improvements, as torch.compile was only released recently with PyTorch 2.0.
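For reference, the usage is just a wrapper around your module or function (a sketch; whether it actually speeds things up is something you’d need to benchmark on your model):

compiled_net = torch.compile(net)  # compiles lazily on the first call
out = compiled_net(x)              # later calls reuse the compiled graph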

Furthermore, be careful with re-wrapping your tensors, e.g. self.A and self.mu, as that’ll break any gradient flow for those parameters. You could also initialize them directly on the GPU via torch.as_tensor(data, dtype=torch.float, device=self.device), which may be slightly quicker too.
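For example (a sketch, assuming self.A already exists as a tensor that requires grad, and data is your raw array):

# re-wrapping creates a new leaf tensor detached from the graph; avoid:
# self.A = torch.tensor(self.A, dtype=torch.float, device=self.device)

# moving the existing tensor keeps it connected to autograd:
self.A = self.A.to(self.device)

# or build the tensor on the GPU in the first place:
self.mu = torch.as_tensor(data, dtype=torch.float, device=self.device)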