Calculating the Jacobian of gradients w.r.t. the true output

Hello,

I’m trying to calculate the Jacobian of the gradients of a model w.r.t. its true output.

Basically, I need to calculate this:
[equation image: Jacobian of the model's gradients w.r.t. its true output]

It is described in Section 2 of Mo et al., 2021 ([2010.08762] Layer-wise Characterization of Latent Information Leakage in Federated Learning). There is no official implementation, so I’m not sure how to approach this.
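One reading of the requested quantity, assuming θ are the model parameters, x the input, y the true output (ground-truth target) and ℒ the training loss (these symbols are my assumptions, not taken from the paper), is the Jacobian of the parameter gradient with respect to y. Row p corresponds to parameter θ_p and column k to target component y_k:

$$
J(x, y) = \frac{\partial}{\partial y}\,\nabla_\theta \mathcal{L}\big(f_\theta(x),\, y\big) \in \mathbb{R}^{|\theta| \times \dim(y)},
\qquad
J_{pk} = \frac{\partial^2 \mathcal{L}}{\partial y_k\,\partial \theta_p}.
$$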

Could anyone please help out?

This can be done with torch.autograd.grad, although I don’t know how it compares performance-wise to torch.func (e.g. torch.func.jacrev).
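For comparison, here is a sketch of the same idea with torch.func (functional_call, grad and jacrev), differentiating the parameter gradients w.r.t. a ground-truth target y rather than the prediction. The toy model, the MSE-style loss and the helper names loss_fn and grads_of_params are assumptions for illustration, and no performance comparison with the autograd.grad version is implied.

```python
import torch
from torch.func import functional_call, grad, jacrev

# Hypothetical toy setup: a small model, one input x and a ground-truth target y.
model = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
params = {name: p.detach() for name, p in model.named_parameters()}
x = torch.randn(2)
y = torch.randn(2)

def loss_fn(params, x, y):
    # Stateless forward pass that uses the given parameter dict (assumed MSE-style loss).
    out = functional_call(model, params, (x,))
    return ((out - y) ** 2).mean()

def grads_of_params(y, params, x):
    # Gradient of the loss w.r.t. every parameter, viewed as a function of the target y.
    return grad(loss_fn, argnums=0)(params, x, y)

# Jacobian of each parameter gradient w.r.t. y:
# a dict mapping parameter name -> tensor of shape (*param.shape, *y.shape).
jac = jacrev(grads_of_params, argnums=0)(y, params, x)
```

Because jacrev returns a pytree that mirrors the gradient dict, each entry keeps its parameter's shape followed by the shape of y, so no manual flattening is needed.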

The functions below compute the Jacobian of a sequence of tensors w.r.t. another sequence of tensors. There are two versions, a regular (loop-based) one and a batched one. The batched implementation relies on the is_grads_batched argument explained in the torch.autograd.grad documentation (PyTorch 2.5); it is basically faster but experimental.
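To illustrate what is_grads_batched does, here is a minimal sketch on a toy elementwise function f(x) = x ** 2 (the function and values are assumptions for illustration), whose Jacobian is diag(2x). Each identity row is one vector-Jacobian product seed; the batched call computes all rows at once, while the loop does one backward pass per row.

```python
import torch

# Toy example: f(x) = x ** 2 applied elementwise, so the Jacobian is diag(2 * x).
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2
seeds = torch.eye(y.numel())

# Batched: dim 0 of grad_outputs is treated as a batch of seeds (one Jacobian row each).
# retain_graph=True so the loop below can reuse the same graph.
(jac,) = torch.autograd.grad(y, x, seeds, retain_graph=True, is_grads_batched=True)

# Loop: one vector-Jacobian product per row, keeping the graph alive between calls.
rows = [torch.autograd.grad(y, x, e, retain_graph=True)[0] for e in seeds]
jac_loop = torch.stack(rows)
```

Both approaches produce the same matrix; the batched one simply vectorizes the rows instead of looping in Python.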

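from collections.abc import Sequence
import torch

def jacobian(input: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
    # Flatten all tensors in `input` into one vector; each entry gives one Jacobian row.
    flat_input = torch.cat([i.reshape(-1) for i in input])
    # Row i of the identity selects entry i of flat_input as the vector-Jacobian product seed.
    grad_outputs = torch.eye(flat_input.numel(), device=flat_input.device, dtype=flat_input.dtype)
    jac = []
    for i in range(flat_input.numel()):
        jac.append(torch.autograd.grad(
            flat_input,
            wrt,
            grad_outputs[i],
            retain_graph=True,  # the same graph is reused for every row
            create_graph=create_graph,
            allow_unused=True,  # tensors in `wrt` that don't affect `input` come back as None
            is_grads_batched=False,
        ))
    # Regroup the rows per `wrt` tensor: each result has shape (flat_input.numel(), *w.shape).
    return [torch.stack(z) for z in zip(*jac)]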

def batched_jacobian(input: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
    flat_input = torch.cat([i.reshape(-1) for i in input])
    # All identity rows are passed at once; is_grads_batched=True treats dim 0 as a batch
    # and vectorizes the backward passes (experimental, see the torch.autograd.grad docs).
    return torch.autograd.grad(
        flat_input,
        wrt,
        torch.eye(flat_input.numel(), device=flat_input.device, dtype=flat_input.dtype),
        retain_graph=True,
        create_graph=create_graph,
        allow_unused=True,
        is_grads_batched=True,
    )

Here is an example of calculating the Jacobian of a model's parameter gradients w.r.t. its output (the model's prediction; this toy loss has no target):

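import torch
model = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))

outputs = model(torch.randn(2))

loss = (outputs ** 2).mean()

# create_graph=True so the gradients themselves can be differentiated again
grad = jacobian([loss], list(model.parameters()), create_graph=True)

# shape: (total number of parameters, outputs.numel()) = (17, 2)
jac_wrt_output = jacobian(grad, [outputs])[0]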
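The example above differentiates w.r.t. the model's prediction. If "true output" means the ground-truth target, as in the question, one option is to make the target a tensor with requires_grad=True and differentiate the parameter gradients w.r.t. it. This is a sketch assuming a linear model and a 0.5 * squared-error loss (both assumptions, chosen so the result can be checked by hand).

```python
import torch

torch.manual_seed(0)

# Hypothetical setup: the target y requires grad so the parameter gradients
# remain differentiable w.r.t. it.
model = torch.nn.Linear(2, 2)
x = torch.randn(2)
y = torch.randn(2, requires_grad=True)

loss = 0.5 * ((model(x) - y) ** 2).sum()

# create_graph=True keeps the backward pass itself in the autograd graph.
grads = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
flat_grads = torch.cat([g.reshape(-1) for g in grads])  # weight entries first, then bias

# One vector-Jacobian product per gradient entry gives one row of the Jacobian.
rows = [
    torch.autograd.grad(flat_grads, y, e, retain_graph=True)[0]
    for e in torch.eye(flat_grads.numel())
]
jac = torch.stack(rows)  # shape: (num_params, y.numel()) = (6, 2)
```

For this loss the bias gradient is (model(x) - y), so its rows of the Jacobian are -I, and the weight gradient is (model(x) - y) xᵀ, so d grad_W[i, j] / d y_k = -δ_ik x_j.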