Optimize an objective involving the Jacobian

Suppose I have two functions, z = f(x) and g(z), and I want to compute
z1 = f(x1)
z2 = f(x2)
and then optimize MSE(∇g(z1), ∇g(z2)), i.e. the mean squared error between the Jacobians of g evaluated at z1 and z2.
But when I backpropagate, the gradient of the parameters of f is all zeros.
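
To write the objective out explicitly (with J_g denoting the Jacobian of g and theta the parameters of f):

$$
\mathcal{L}(\theta) \;=\; \mathrm{MSE}\!\left( J_g\big(f_\theta(x_1)\big),\; J_g\big(f_\theta(x_2)\big) \right)
$$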

Here is the code:

import torch
from torch import nn

def batch_jacobian(func, x, create_graph=True, strict=True):
    # Summing over the batch dimension before calling jacobian() works because
    # each row of func(x) depends only on the corresponding row of x, so the
    # Jacobian of the summed output still contains the per-sample Jacobians;
    # permute(1, 0, 2) just moves the batch dimension back to the front.
    def _func_sum(x):
        return func(x).sum(dim=0)
    return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph, strict=strict).permute(1,0,2)

g = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
x1 = torch.rand((2, 3))
x2 = torch.rand((2, 3))
mseloss = nn.MSELoss()

theta = torch.arange(6).reshape(3, 2) / 10
theta.requires_grad = True
print(theta)
sig = torch.nn.Sigmoid()
def f(x):
    return sig(x @ theta)

opt = torch.optim.Adam([theta] + list(g.parameters()), lr=1e-3)
opt.zero_grad()
loss = mseloss(batch_jacobian(g, f(x1)), batch_jacobian(g, f(x2)))
loss.backward()
opt.step()
print(theta.grad)   # comes out all zeros

And the gradient of theta comes out all zeros.

Hi af!

The problem is that torch.autograd.functional.jacobian() only backpropagates back to
its inputs argument (your f(x1) and f(x2)). It neither knows nor cares that, say, f(x1)
depends on theta. Because you don't backpropagate through f(x1), you never reach
the dependence on theta, so you get no .grad for theta.

I’ve tweaked your script so that the call to f(x) occurs inside of batch_jacobian() and
therefore inside of the call to torch.autograd.functional.jacobian(). Doing so then
produces a .grad for theta.

Here is the tweaked script:

import torch
print (torch.__version__)

torch.manual_seed (2025)

from torch import nn

theta = torch.arange(6).reshape(3, 2) / 10
theta.requires_grad = True
print ('theta = ...')
print (theta)
sig = torch.nn.Sigmoid()
def f(x):
    return sig(x @ theta)

def batch_jacobian(func, x, create_graph=True, strict=True):
    def _func_sum(x):
        # return func(x).sum(dim=0)
        # apply f() inside the function being differentiated, so the Jacobian is
        # taken of the composition g (f (x)) and the graph reaches theta
        return func (f (x)).sum (dim=0)
    return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph, strict=strict).permute(1,0,2)

g = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
x1 = torch.rand((2, 3))
x2 = torch.rand((2, 3))
mseloss = nn.MSELoss()

opt = torch.optim.Adam([theta] + list(g.parameters()), lr=1e-3)
opt.zero_grad()
# loss = mseloss(batch_jacobian(g, f(x1)), batch_jacobian(g, f(x2)))
loss = mseloss(batch_jacobian(g, x1), batch_jacobian(g, x2))
loss.backward()
opt.step()

print ('loss:', loss)
print ('theta = ...')
print (theta)
print ('theta.grad = ...')
print (theta.grad)

And here is its output:

2.6.0+cu126
theta = ...
tensor([[0.0000, 0.1000],
        [0.2000, 0.3000],
        [0.4000, 0.5000]], requires_grad=True)
loss: tensor(3.6073e-08, grad_fn=<MseLossBackward0>)
theta = ...
tensor([[0.0010, 0.0990],
        [0.2009, 0.2990],
        [0.4009, 0.4990]], requires_grad=True)
theta.grad = ...
tensor([[-2.0444e-07,  4.2059e-07],
        [-1.0772e-07,  2.4320e-07],
        [-1.5494e-07,  3.4528e-07]])
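
As an aside, newer versions of pytorch also provide torch.func.jacrev() and torch.vmap(), which compute per-sample Jacobians and are designed to compose with autograd. Here is a minimal sketch of the same batch Jacobian of the composition g (f (x)) (the gf helper is just a name I made up for this sketch; I believe gradients should reach theta this way as well, but it's worth verifying on your version of pytorch):

import torch
from torch import nn

torch.manual_seed(2025)

theta = (torch.arange(6).reshape(3, 2) / 10).requires_grad_()
sig = torch.nn.Sigmoid()
g = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

def gf(x):                       # g(f(x)) for a single sample x of shape (3,)
    return g(sig(x @ theta))

x1 = torch.rand((2, 3))
x2 = torch.rand((2, 3))

# per-sample Jacobians of g(f(x)) with respect to x, shape (2, 3, 3)
jac1 = torch.vmap(torch.func.jacrev(gf))(x1)
jac2 = torch.vmap(torch.func.jacrev(gf))(x2)

loss = torch.nn.functional.mse_loss(jac1, jac2)
loss.backward()                  # assumption: the graph built inside jacrev() reaches theta
print(theta.grad)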

Best.

K. Frank