This can be done with torch.autograd.grad, although I don't know how it compares performance-wise to torch.func.
These functions compute the Jacobian of a sequence of tensors w.r.t. another sequence of tensors. There are two versions, a plain one and a batched one. The batched implementation uses the is_grads_batched argument described in the torch.autograd.grad — PyTorch 2.5 documentation; it is generally faster but still experimental.
from collections.abc import Sequence

import torch


def jacobian(input: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
    # Flatten all inputs into one 1D tensor and build one-hot grad_outputs,
    # one row per scalar entry of the flattened input.
    flat_input = torch.cat([i.reshape(-1) for i in input])
    grad_outputs = torch.eye(len(flat_input), device=input[0].device, dtype=input[0].dtype)
    jac = []
    for i in range(flat_input.numel()):
        # One backward pass per scalar entry of the flattened input.
        jac.append(torch.autograd.grad(
            flat_input,
            wrt,
            grad_outputs[i],
            retain_graph=True,
            create_graph=create_graph,
            allow_unused=True,
            is_grads_batched=False,
        ))
    # Each returned tensor has shape (flat_input.numel(), *w.shape) for the corresponding w in wrt.
    return [torch.stack(z) for z in zip(*jac)]
def batched_jacobian(input: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
    # Same as jacobian(), but feeds every row of the identity matrix at once
    # via is_grads_batched=True (vmap under the hood).
    flat_input = torch.cat([i.reshape(-1) for i in input])
    return torch.autograd.grad(
        flat_input,
        wrt,
        torch.eye(len(flat_input), device=input[0].device, dtype=input[0].dtype),
        retain_graph=True,
        create_graph=create_graph,
        allow_unused=True,
        is_grads_batched=True,
    )
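As a quick sanity check (my own toy example, not from the original snippet), both versions should agree on a simple function, and for a function of a plain tensor the result can also be compared against torch.func.jacrev:

import torch
from torch.func import jacrev

def f(x):
    # Two scalar outputs of a 3-element input, so the Jacobian is 2 x 3.
    return torch.stack([x.sin().sum(), (x ** 2).sum()])

x = torch.randn(3, requires_grad=True)
y = f(x)

j_loop = jacobian([y], [x])[0]             # shape (2, 3)
j_batched = batched_jacobian([y], [x])[0]  # shape (2, 3)
j_func = jacrev(f)(x.detach())             # torch.func reference

assert torch.allclose(j_loop, j_batched)
assert torch.allclose(j_loop, j_func)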
Here is an example that computes the gradient of a model's loss w.r.t. its parameters, and then the Jacobian of those gradients w.r.t. the model's outputs:
import torch

model = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
outputs = model(torch.randn(2))
loss = (outputs ** 2).mean()

# Gradient of the loss w.r.t. the parameters, with create_graph=True so it can
# be differentiated again.
grad = jacobian([loss], list(model.parameters()), create_graph=True)

# Jacobian of those gradients w.r.t. the model outputs.
jac_wrt_output = jacobian(grad, [outputs])[0]
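If I am counting the parameters right, grad is a list with one tensor of shape (1, *p.shape) per parameter (17 flattened gradient entries for this model: 6 + 3 for the first Linear, 6 + 2 for the second), and jac_wrt_output[i, j] is the derivative of the i-th flattened gradient entry w.r.t. the j-th model output, so:

print(jac_wrt_output.shape)  # torch.Size([17, 2])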