# How to estimate the gradient of an argmin loss

Suppose we have a neural network $f_\theta(x)$, where $x$ is the input and $\theta$ are the network's parameters.

For each $\theta$, we can minimize $f_\theta(x)$ with respect to $x$ and obtain the minimum point $x^*(\theta) := \arg\min_x f_\theta(x)$.

My question is: how can we estimate the gradient $dx^*(\theta) / d\theta$?

I did some research and found that this gradient (in fact, any gradient) can be estimated with the finite-difference technique. However, it is rather slow when $\theta$ includes many parameters, e.g., when $f_\theta$ is a large neural network.

Any possible approach is welcome.

Hi Med!

Use calculus to expand f to second order around x_star and theta. Solve
for x = argmin f (x) as a function of theta in the context of this second-order
expansion and then compute the gradient of this x with respect to theta. This
will give you an analytic result for the desired gradient in terms of second-order
partial derivatives of f (x; theta) with respect to x and theta.

Then use autograd.grad() twice in a row to compute numerically the required
partial derivatives for your specific network, f, and compute from them the gradient
of x_star with respect to the parameters of your network, the theta.
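To make the "autograd.grad() twice" step concrete, here is a minimal sketch on a hypothetical one-line function f(x; w) = w * x**3 (not the network from this thread), whose partials are easy to check by hand: df/dx = 3 w x², d²f/dx² = 6 w x and d²f/dx dw = 3 x².

```python
import torch

# Hypothetical toy: f(x; w) = w * x**3
x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(0.5, requires_grad=True)

f = w * x ** 3

# first call: create_graph=True keeps df/dx differentiable
df_dx, = torch.autograd.grad(f, x, create_graph=True)

# second call: differentiate df/dx with respect to x and w at once
d2f_dx2, d2f_dx_dw = torch.autograd.grad(df_dx, (x, w))

print(df_dx.item(), d2f_dx2.item(), d2f_dx_dw.item())   # 6.0 6.0 12.0
```

The second call returns the curvature term and the mixed term in one pass, which are exactly the two ingredients the argmin gradient needs.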

f (x_star + delta_x; theta + delta_theta) is given, to second order, by:

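$$
f(x^*;\, \theta)
+ \delta_x \frac{\partial f}{\partial x} + \delta_\theta \cdot \frac{\partial f}{\partial \theta}
+ \frac{1}{2}\,\delta_x^2 \frac{\partial^2 f}{\partial x^2}
+ \delta_x\, \delta_\theta \cdot \frac{\partial^2 f}{\partial x\, \partial \theta}
+ \frac{1}{2}\,\delta_\theta \cdot \frac{\partial^2 f}{\partial \theta^2} \cdot \delta_\theta
$$

where, for the theta terms, $\cdot$ is understood to denote the appropriate
matrix-vector multiplication.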

Keeping only the terms that depend on x and recognizing that d_f/d_x is zero
at x_star (the minimum of f (x)), we get

$$
f(x^* + \delta_x;\, \theta + \delta_\theta) \simeq
\frac{1}{2}\,\delta_x^2 \frac{\partial^2 f}{\partial x^2}
+ \delta_x\, \delta_\theta \cdot \frac{\partial^2 f}{\partial x\, \partial \theta}
$$


For a small change in theta, delta_theta, the new argmin of f occurs when
d_f (delta_x) / d_delta_x is zero. Setting this derivative of the above expression
to zero and solving for delta_x we get:

$$
\delta_x = -\frac{\delta_\theta \cdot \partial^2 f / \partial x\, \partial \theta}{\partial^2 f / \partial x^2}
$$


This delta_x is linear in delta_theta, so its gradient with respect to delta_theta
is simply $-\dfrac{\partial^2 f / \partial x\, \partial \theta}{\partial^2 f / \partial x^2}$, and this is the desired gradient of
x_star = argmin_x f (x; theta) with respect to theta.
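As a sanity check of this formula, here is a minimal pure-Python sketch on a hypothetical objective whose argmin is known in closed form, f(x; a, b) = a x² + b x with a > 0 (not the network from this thread). Both printed pairs should agree.

```python
# Hypothetical toy objective: f(x; a, b) = a*x**2 + b*x with a > 0.
# Closed form: x* = -b / (2a), so dx*/da = b / (2 a^2) and dx*/db = -1 / (2a).

def x_star(a, b):
    return -b / (2 * a)

def implicit_grad(a, b):
    """dx*/dtheta = -(d2f/dx dtheta) / (d2f/dx2), evaluated at x = x*."""
    xs = x_star(a, b)
    d2f_dx2 = 2 * a       # df/dx = 2*a*x + b, differentiated w.r.t. x
    d2f_dx_da = 2 * xs    # ... differentiated w.r.t. a
    d2f_dx_db = 1.0       # ... differentiated w.r.t. b
    return -d2f_dx_da / d2f_dx2, -d2f_dx_db / d2f_dx2

a, b = 1.5, -0.7
print(implicit_grad(a, b))               # from the second-order formula
print(b / (2 * a ** 2), -1 / (2 * a))    # from differentiating x* = -b / (2a) directly
```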

Here is a pytorch script that performs this computation for a specific concrete
network:

```python
import torch
print (torch.__version__)

_ = torch.manual_seed (2023)

f = torch.nn.Sequential (   # network f, with fixed randomly initialized parameters (theta)
    torch.nn.Linear (1, 3),
    torch.nn.Tanh(),
    torch.nn.Linear (3, 1, bias = False)
)

x = torch.tensor ([0.0], requires_grad = True)   # vary x to minimize f(x)

# use pytorch optimizer to find argmin of f(x)
opt = torch.optim.Adam ([x], lr = 0.5)
for  i in range (350):
    opt.zero_grad()
    l = f (x)
    l.backward()
    opt.step()

# x is now x_star, the argmin of f(x) (for fixed theta)
x_star = x.detach().clone()   # save a copy for later
print ('x_star:', x_star.item())
print ('l:', l.item())

# compute second partials, d^2_f / d_x^2 and d^2_f / d_x d_theta, evaluated at x = x_star

l = f (x)
d_x = torch.autograd.grad (l, x, retain_graph=True, create_graph=True)[0]   # first derivative of f with respect to x
print ('d_x:', d_x)                                                         # d_x is zero because x_star is a stationary point
d2_x = torch.autograd.grad (d_x, x, retain_graph = True)[0]                 # second derivative of f with respect to x
print ('d2_x:', d2_x)                                                       # d2_x is positive because that stationary point is a minimum
d_x_d_th = torch.autograd.grad (d_x, list (f.parameters()))                 # mixed partials d^2_f / d_x d_theta, one tensor per parameter

print ('d_x_d_th:')
for  t in d_x_d_th:
    print (t)

d_x_star_d_weight_1 = -d_x_d_th[0] / d2_x                                   # gradient with respect to weight of first Linear
d_x_star_d_bias_1 = -d_x_d_th[1] / d2_x                                     # gradient with respect to bias of first Linear
d_x_star_d_weight_2 = -d_x_d_th[2] / d2_x                                   # gradient with respect to weight of second Linear

print ('d_x_star_d_theta:')
print (d_x_star_d_weight_1)
print (d_x_star_d_bias_1)
print (d_x_star_d_weight_2)

# numerically check one value of gradient
delta = 1.e-4
with torch.no_grad():                # in-place change to a parameter must not be tracked by autograd
    f[0].weight[1, 0] += delta

# recompute x_star for perturbed theta (starting from the old x_star)
for  i in range (125):
    opt.zero_grad()
    l = f (x)
    l.backward()
    opt.step()

grad_num = (x - x_star) / delta
print ('grad_num:', grad_num)        # compare with d_x_star_d_weight_1[1, 0]
```



And here is its output:

```
1.13.1
x_star: 4.385826110839844
l: -0.6181716918945312
d2_x: tensor([0.0069])
d_x_d_th:
tensor([[-0.0194],
[ 0.0611],
[ 0.0005]])
tensor([0.0115, 0.0190, 0.0001])
tensor([[-0.0941,  0.0303,  0.0002]])
d_x_star_d_theta:
tensor([[ 2.8206],
[-8.8939],
[-0.0791]])
tensor([-1.6800, -2.7720, -0.0207])
tensor([[13.7074, -4.4159, -0.0283]])
```


Best.

K. Frank


I notice you jitter $\theta$ to $\theta+\delta$ in order to estimate the gradient w.r.t. $\theta$.

What is the difference between your method and the finite-difference technique? Why do you need to expand f to second order first?

The finite-difference technique directly jitters $\theta$ to $\theta+\delta$, minimizes $f_{\theta+\delta}$ to obtain $x^*_{new}$, and then estimates the gradient as $(x^*_{new} - x^*_{old}) / \delta$.
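For comparison, here is a minimal pure-Python sketch of the finite-difference procedure just described, on a hypothetical three-parameter toy objective where the inner minimization is plain gradient descent. The point is the cost: every parameter needs its own full re-minimization, so n parameters cost n + 1 inner solves.

```python
# Hypothetical toy objective (not from the thread):
#   f(x; theta) = 0.5 * k * x**2 - s * x,  k = 1 + theta[0]**2,  s = theta[1] + 2*theta[2]
# Its argmin is x* = s / k, so the estimates can be checked exactly.

def argmin_x(theta, steps=2000, lr=0.1):
    """Minimize f(x; theta) over x by plain gradient descent."""
    k = 1.0 + theta[0] ** 2
    s = theta[1] + 2.0 * theta[2]
    x = 0.0
    for _ in range(steps):
        x -= lr * (k * x - s)          # df/dx = k*x - s
    return x

def finite_difference_grad(theta, delta=1e-6):
    """Forward differences (x*_new - x*_old) / delta: one extra inner solve per parameter."""
    x_old = argmin_x(theta)
    solves = 1
    grad = []
    for i in range(len(theta)):
        jittered = list(theta)
        jittered[i] += delta
        x_new = argmin_x(jittered)
        solves += 1
        grad.append((x_new - x_old) / delta)
    return grad, solves

theta = [0.5, 1.0, -0.3]
grad, solves = finite_difference_grad(theta)
print(grad, solves)   # approximately [-0.256, 0.8, 1.6], from 4 inner solves
```

With millions of network parameters this loop is what makes finite differences slow, whereas for a scalar x a single extra backward pass gives the mixed partials for all parameters at once.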

Hi Med!

Just to be clear, I did two separate things. The main computation used the
second-order expansion of f (x; theta), plus autograd’s ability to compute
second derivatives, to compute the gradient of x_star with respect to theta.

There are no finite differences here – neither in the expression for that gradient
nor in autograd’s computation of the second partials.

Then, separately, I performed a numerical finite-difference computation (“jitter
the \theta to \theta+\delta”) of just one element of the full gradient, that is, I
estimated the derivative of x_star with respect to just one element of one
of the parameter tensors. This was just a numerical cross-check to show that
the first computation was correct.

As noted above, the main computation has nothing to do with finite differences,
while the second computation – the numerical cross-check – is indeed a standard
finite-difference estimate of the derivative.

x_star – being the argmin – sits at the bottom of the “bowl” defined by f (x).
The second derivative of f (x) with respect to x tells us the curvature of the
bottom of the bowl – that is, is the bowl relatively flat or more steep?

The second-order mixed partials of f (x; theta) with respect to x and theta
give us information about how changes in theta change the bowl, and, in
particular, how they change the location of the bottom of the bowl (which is to
say, how they change x_star).

I think about it like this: The mixed partials tell us how hard changes in theta
“push” on x_star, while the second derivative with respect to x tells us how
firmly x_star resists that push – a steeper bowl means that x_star won’t
move as far in response to that push, while a flatter bowl means that x_star
will move farther. Combined together, the two tell us how far x_star moves
when theta changes, or, in more mathematical language, what the gradient
of x_star with respect to theta is.
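The push-versus-resistance picture can be seen in a hypothetical one-parameter bowl where everything is explicit. The push is identical for both bowls; only the curvature changes, and the flat bowl's minimum moves eight times as far as the steep one's.

```python
# Hypothetical bowl: f(x; theta) = 0.5 * a * x**2 - theta * x, curvature a > 0,
# with argmin x* = theta / a.
#   "push":       d2f/(dx dtheta) = -1   (the same for every bowl)
#   "resistance": d2f/dx2         =  a

def argmin_shift(a, theta, d_theta):
    """How far the bottom of the bowl actually moves when theta changes."""
    return (theta + d_theta) / a - theta / a

def implicit_rate(a):
    """dx*/dtheta = -push / resistance."""
    push, resistance = -1.0, a
    return -push / resistance

for a in (0.5, 4.0):   # flat bowl vs steep bowl
    print(a, argmin_shift(a, theta=1.0, d_theta=0.1), implicit_rate(a) * 0.1)
```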

Best.

K. Frank

Thanks for the explanation. Problem solved!

```python
d_x_star_d_weight_1 = -d_x_d_th[0] / d2_x                                   # gradient with respect to weight of first Linear
d_x_star_d_bias_1 = -d_x_d_th[1] / d2_x                                     # gradient with respect to bias of first Linear
d_x_star_d_weight_2 = -d_x_d_th[2] / d2_x                                   # gradient with respect to weight of second Linear
```


When x is multi-dimensional, the division ... / d2_x will not work. You should use the inverse of the Hessian matrix (the Jacobian of d_f/d_x with respect to x) here.

FYI, you could try out torchopt; see its Implicit Gradient Differentiation documentation for more details.
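Following this suggestion, here is a minimal sketch of the multi-dimensional formula dx*/dθ = -H⁻¹ ∂²f/∂x∂θ, assuming x_star is a converged 1-D tensor with a positive-definite Hessian. The helper argmin_grad and the quadratic test objective (whose exact answer is A⁻¹) are hypothetical; the sketch uses torch.linalg.solve instead of forming the inverse explicitly, which is generally the numerically preferable choice.

```python
import torch

def argmin_grad(f, x_star):
    """Implicit gradient of x* = argmin_x f(x; theta) w.r.t. the parameters of module f.

    Assumes x_star is a 1-D tensor where grad_x f is (numerically) zero and the
    Hessian H = d^2 f / dx^2 is positive definite. Returns one tensor per
    parameter, of shape (n, *param.shape), holding dx*/dparam.
    """
    params = [p for p in f.parameters() if p.requires_grad]
    x = x_star.detach().clone().requires_grad_(True)
    n = x.numel()
    g = torch.autograd.grad(f(x).sum(), x, create_graph=True)[0]   # grad_x f, shape (n,)

    hess_rows, mixed_rows = [], []
    for i in range(n):
        # differentiate the i-th component of grad_x f w.r.t. x and all parameters
        grads = torch.autograd.grad(g[i], [x] + params, retain_graph=True, allow_unused=True)
        grads = [torch.zeros_like(t) if gr is None else gr for gr, t in zip(grads, [x] + params)]
        hess_rows.append(grads[0])                                            # row i of H
        mixed_rows.append(torch.cat([gr.reshape(-1) for gr in grads[1:]]))   # row i of d^2 f / dx dtheta
    H = torch.stack(hess_rows)     # (n, n)
    M = torch.stack(mixed_rows)    # (n, P), P = total number of parameter elements

    J = -torch.linalg.solve(H, M)  # dx*/dtheta = -H^{-1} M, shape (n, P)
    sizes = [p.numel() for p in params]
    return [j.reshape(n, *p.shape) for j, p in zip(torch.split(J, sizes, dim=1), params)]


class Quadratic(torch.nn.Module):
    """Hypothetical test objective f(x; theta) = 0.5 * x^T A x - theta . x, with x* = A^{-1} theta."""
    def __init__(self, A):
        super().__init__()
        self.register_buffer("A", A)
        self.theta = torch.nn.Parameter(torch.tensor([1.0, -2.0, 0.5], dtype=A.dtype))

    def forward(self, x):
        return 0.5 * x @ self.A @ x - self.theta @ x


A = torch.tensor([[3.0, 1.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 1.0]], dtype=torch.float64)
f = Quadratic(A)
with torch.no_grad():
    x_star = torch.linalg.solve(A, f.theta)   # exact argmin of this toy objective
(d_x_star_d_theta,) = argmin_grad(f, x_star)
print(torch.allclose(d_x_star_d_theta, torch.linalg.inv(A)))   # True, since dx*/dtheta = A^{-1} here
```

It needs one backward pass per component of x, so it suits a small x with many parameters, which is the situation in this thread.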


Hi, Frank.

I tried your code and was able to reproduce the result for random seed 2023.

However, when I set the seed to 456, the analytical and numerical gradients are quite different.

Any idea why this could happen?

My code:

```python
import torch
print (torch.__version__)

_ = torch.manual_seed (456)

f = torch.nn.Sequential (   # network f, with fixed randomly initialized parameters (theta)
    torch.nn.Linear (1, 3),
    torch.nn.Tanh(),
    torch.nn.Linear (3, 1, bias = False)
)

x = torch.tensor ([0.0], requires_grad = True)   # vary x to minimize f(x)

# use pytorch optimizer to find argmin of f(x)
opt = torch.optim.Adam ([x], lr = 0.5)
for  i in range (350):
    opt.zero_grad()
    l = f (x)
    l.backward()
    opt.step()

# x is now x_star, the argmin of f(x) (for fixed theta)
x_star = x.detach().clone()   # save a copy for later
print ('x_star:', x_star.item())
print ('l:', l.item())

# compute second partials, d^2_f / d_x^2 and d^2_f / d_x d_theta, evaluated at x = x_star

l = f (x)
d_x = torch.autograd.grad (l, x, retain_graph=True, create_graph=True)[0]   # first derivative of f with respect to x
print ('d_x:', d_x)                                                         # d_x is zero because x_star is a stationary point
d2_x = torch.autograd.grad (d_x, x, retain_graph = True)[0]                 # second derivative of f with respect to x
print ('d2_x:', d2_x)                                                       # d2_x is positive because that stationary point is a minimum
d_x_d_th = torch.autograd.grad (d_x, list (f.parameters()))                 # mixed partials d^2_f / d_x d_theta, one tensor per parameter

print ('d_x_d_th:')
for  t in d_x_d_th:
    print (t)

d_x_star_d_weight_1 = -d_x_d_th[0] / d2_x                                   # gradient with respect to weight of first Linear
d_x_star_d_bias_1 = -d_x_d_th[1] / d2_x                                     # gradient with respect to bias of first Linear
d_x_star_d_weight_2 = -d_x_d_th[2] / d2_x                                   # gradient with respect to weight of second Linear

print ('d_x_star_d_theta:')
print (d_x_star_d_weight_1)
print (d_x_star_d_bias_1)
print (d_x_star_d_weight_2)

# numerically check one value of gradient
delta = 1.e-4
with torch.no_grad():                # in-place change to a parameter must not be tracked by autograd
    f[0].weight[1, 0] += delta

# recompute x_star for perturbed theta (starting from the old x_star)
for  i in range (125):
    opt.zero_grad()
    l = f (x)
    l.backward()
    opt.step()

grad_num = (x - x_star) / delta
print ('grad_num:', grad_num)        # compare with d_x_star_d_weight_1[1, 0]
```



My output:

```
1.10.0
x_star: -12.196703910827637
l: -0.7395662069320679
d2_x: tensor([0.0003])
d_x_d_th:
tensor([[-0.0081],
[ 0.0000],
[ 0.0002]])
tensor([ 7.4753e-04, -0.0000e+00, -2.0273e-05])
tensor([[6.7220e-04, 0.0000e+00, 2.5887e-05]])
d_x_star_d_theta:
tensor([[30.9327],
[-0.0000],
[-0.8667]])
tensor([-2.8593,  0.0000,  0.0775])
tensor([[-2.5712, -0.0000, -0.0990]])
```


Hello, Frank and Xuehai:

I implemented a toy model to estimate dx_dtheta where x is a multi-dimensional vector.

Specifically, I use dx_dtheta = - H_f[x]^-1 * d2f_dxdtheta, where H_f[x]^-1 is the inverse Hessian and d2f_dxdtheta is the mixed partial.

However, the dx_dtheta I estimated is very strange: its entries are on the order of 10^15.
In contrast, the result estimated via finite differences is of order 1 to 10.

My code and outputs are attached below. Any idea which part is wrong?

Code:

```python
import torch
_ = torch.manual_seed(123)
torch.set_default_dtype(torch.float64)

print(torch.__version__)

f = torch.nn.Sequential (   # network f, with fixed randomly initialized parameters (theta)
    torch.nn.Linear (3, 2),
    torch.nn.Tanh(),
    torch.nn.Linear (2, 1, bias = False)
)

x = torch.tensor ([0.0, 0.0, 0.0], requires_grad = True)   # vary x to minimize f(x)

# use pytorch optimizer to find argmin of f(x)
opt = torch.optim.Adam ([x], lr = 0.5)
for  i in range (350):
    opt.zero_grad()
    l = f (x)
    l.backward()
    opt.step()

# x is now x_star, the argmin of f(x) (for fixed theta)
x_star = x.detach().clone()   # save a copy for later
print ('x_star:', x_star)
print ('l:', l.item())

# compute second partials, d^2_f / d_x^2 and d^2_f / d_x d_theta, evaluated at x = x_star
l = f (x)
d_x = torch.autograd.grad (l, x, retain_graph=True, create_graph=True)[0]   # first derivative of f with respect to x
print ('d_x:', d_x)
d2_x = torch.stack([torch.autograd.grad(d_x[ind], x, retain_graph=True)[0] for ind in range(3)])   # Hessian d^2_f / d_x^2, shape (3, 3)
print ('d2_x:', d2_x)

# only look at the first fc layer's weight
theta = list(f.parameters())[0]

d_f_d_x_d_th = torch.stack([torch.autograd.grad(d_x[ind], theta, retain_graph=True, create_graph=True)[0] for ind in range(3)]).reshape(*x.shape,*theta.shape)                        # mixed second partials
print ('d_f_d_x_d_th:', d_f_d_x_d_th)

# use the implicit function theorem (IFT) here: dx/dtheta = -H^-1 * d^2_f / d_x d_theta
d_x_d_th = - torch.matmul(torch.linalg.inv(d2_x), d_f_d_x_d_th.reshape(3,-1))
print('d_x_d_th:', d_x_d_th)

# numerically check one value of gradient
# setting delta to 1e-6, 1e-8 or 1e-10 gives the same numerical values, so it is quite stable

x = torch.tensor ([0.0,0.0,0.0], requires_grad = True)   # vary x to minimize f(x)
opt = torch.optim.Adam ([x], lr = 0.5)

delta = 1.e-10
with torch.no_grad():                # in-place change to a parameter must not be tracked by autograd
    f[0].weight[0, 0] += delta

# recompute x_star for perturbed theta
for  i in range (350):
    opt.zero_grad()
    l = f (x)
    l.backward()
    opt.step()

grad_num = (x - x_star) / delta

print('analytical: ', d_x_d_th[:,0])
print('numerical: ', grad_num)
```

Outputs:

```
1.10.0
x_star: tensor([-10.2784,   9.1450,  -8.9501])
l: -1.2608992210123384
d_x: tensor([ 1.3880e-04, -1.9742e-04,  2.9152e-05], grad_fn=<SqueezeBackward1>)
d2_x: tensor([[ 2.0108e-04,  7.6194e-05, -2.1971e-05],
[ 7.6194e-05,  2.3750e-04, -4.5307e-05],
d_f_d_x_d_th: tensor([[[ 0.0015, -0.0010,  0.0010],
[ 0.0036, -0.0036,  0.0035]],

[[ 0.0043, -0.0035,  0.0038],
[ 0.0003, -0.0007,  0.0002]],

[[-0.0008,  0.0007, -0.0003],
d_x_d_th: tensor([[ 2.0818e+15,  8.7662e+15,  4.9454e+16, -2.3087e+15,
-9.7215e+15,
-5.4843e+16],
[ 8.7662e+15,  3.6913e+16,  2.0824e+17, -9.7215e+15, -4.0935e+16,
-2.3093e+17],
[ 4.9454e+16,  2.0824e+17,  1.1748e+18, -5.4843e+16, -2.3093e+17,
analytical:  tensor([2.0818e+15, 8.7662e+15, 4.9454e+16],
numerical:  tensor([  0.9589, -12.7340,  15.6670], grad_fn=<DivBackward0>)
```


A reproduction script using torchopt’s implicit gradient:

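```python
import torch
import torch.nn as nn

import torchopt

class ImplicitModel(torchopt.nn.ImplicitMetaGradientModule):   # class header not shown in the post; assumed
    def __init__(self, mlp, x0):
        super().__init__()
        self.mlp = mlp                        # meta-module: its parameters play the role of theta
        self.x = nn.Parameter(x0.clone())     # inner parameter being minimized (assumed; not shown in the post)

    def objective(self):
        return self.mlp(self.x).squeeze()

    @torch.enable_grad()
    def solve(self):
        optimizer = torch.optim.Adam([self.x], lr=0.5)   # optimizer setup not shown in the post; assumed
        for i in range(1000):
            optimizer.zero_grad()
            loss = self.objective()
            loss.backward(inputs=[self.x])
            optimizer.step()
        return self

torch.manual_seed(456)

mlp = torch.nn.Sequential(
    torch.nn.Linear(1, 3),
    torch.nn.Tanh(),
    torch.nn.Linear(3, 1, bias=False),
)
x0 = torch.tensor([0.0])

model = ImplicitModel(mlp, x0)
model.solve()

model.x[0].backward()

for name, meta_param in model.named_meta_parameters():
    print(f'{name}.grad: {meta_param.grad}')
```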


Output for seed 456:

```
mlp.0.weight.grad: tensor([[13.5392],
[-0.0203],
[-3.2621]])
```

Output for seed 2023:

```
mlp.0.weight.grad: tensor([[ 2.6780],
[-5.9816],
[-0.1180]])
```


Reproduction script for multi-dimension inputs:

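```python
import torch
import torch.nn as nn

import torchopt

class ImplicitModel(torchopt.nn.ImplicitMetaGradientModule):   # class header not shown in the post; assumed
    def __init__(self, mlp, x0):
        super().__init__()
        self.mlp = mlp                        # meta-module: its parameters play the role of theta
        self.x = nn.Parameter(x0.clone())     # inner parameter being minimized (assumed; not shown in the post)

    def objective(self):
        return self.mlp(self.x).squeeze()

    @torch.enable_grad()
    def solve(self):
        optimizer = torch.optim.Adam([self.x], lr=0.5)   # optimizer setup not shown in the post; assumed
        for i in range(1000):
            optimizer.zero_grad()
            loss = self.objective()
            loss.backward(inputs=[self.x])
            optimizer.step()
        return self

torch.manual_seed(123)

mlp = torch.nn.Sequential(
    torch.nn.Linear(3, 2),
    torch.nn.Tanh(),
    torch.nn.Linear(2, 1, bias=False),
)
x0 = torch.tensor([0.0, 0.0, 0.0])

for i in range(3):
    for param in mlp.parameters():
        param.grad = None                 # clear gradients from the previous x[i] (assumed; loop body not shown)

    model = ImplicitModel(mlp, x0)
    model.solve()

    print(f'Derivative for x[{i}]')
    model.x[i].backward()
    for name, meta_param in model.named_meta_parameters():
        print(f'{name}.grad: {meta_param.grad}')
    print()
```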


Output for seed 123:

```
Derivative for x[0]
[ -1.9555,   0.3308,  -1.7174]])

Derivative for x[1]
[-11.7130,   2.0863, -10.2923]])

Derivative for x[2]
[-10.7902,   1.8953,  -9.4800]])
```


Under the same setting, the finite difference estimates dx_dmlp.0.weight[0] to be [0.1900, -4.9266, 5.9873], while torchopt gives [-1.5724,-9.2159,-7.7948].

So basically, the inverse Hessian, torchopt, and finite differences give completely different results?

I checked: when x = torch.tensor([0.0]) is a scalar, the inverse Hessian, torchopt, and finite differences all give the same gradient estimate.
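One more property of Med's 3-input toy model may be worth checking (a side observation, not something raised in the thread). The hidden layer has only 2 units, so f depends on the 3-dimensional x only through W1 @ x + b1. Its Hessian with respect to x is then W1ᵀ D W1 for a 2×2 diagonal D, which has rank at most 2, and f is constant along the null direction of W1. The argmin is therefore not unique, and inverting the (numerically singular) Hessian would be expected to produce huge values like the 1e15 entries above. A sketch that checks both facts, reusing the architecture but evaluating at a random x:

```python
import torch

torch.manual_seed(123)
torch.set_default_dtype(torch.float64)

# Same architecture as the 3-input toy model: x (3) -> Linear(3, 2) -> Tanh -> Linear(2, 1)
f = torch.nn.Sequential(
    torch.nn.Linear(3, 2),
    torch.nn.Tanh(),
    torch.nn.Linear(2, 1, bias=False),
)

x = torch.randn(3)
H = torch.autograd.functional.hessian(lambda v: f(v).sum(), x)   # (3, 3)
sv = torch.linalg.svdvals(H)
print(sv)   # the smallest singular value is ~0 relative to the largest

# f depends on x only through W1 @ x, so moving along the null direction of W1 changes nothing
W1 = f[0].weight.detach()
n = torch.linalg.cross(W1[0], W1[1])      # orthogonal to both rows of W1
with torch.no_grad():
    print(f(x).item(), f(x + 5.0 * n).item())   # equal up to rounding
```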

Hi Med!

First note that you haven’t converged to a minimum. Your first derivative,
although smallish, is not zero. (Also your second derivative is relatively
small, so if you were at a minimum, it would be pretty flat.) The formula
I used for the gradient of argmin relies on x_star actually being a minimum
with d_f/d_x equaling zero when evaluated at x_star.

Not having converged to a minimum is not simply a consequence of not
having run enough optimization steps.

Here are the details of what is going on:

The particular “model” I used in my example is not guaranteed to have a
minimum. When the weights of the Linears are initialized starting with a
random seed of 2023, the model does turn out to have a minimum. But
when initialized with a random seed of 456, it does not. As you optimize,
x runs off to -inf and the Tanh “saturates” and becomes flat.

To make this clear, we can add a “stabilizing” term to the model that will
ensure that it has a minimum. So we run the same analysis on the 456
model plus an x**2 term and see that the finite-difference gradient and
“analytic” gradient agree (as well as can be expected).
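The saturation argument can be checked by hand on a hypothetical 1 → 3 → 1 Tanh model with made-up weights (not the seed-456 weights): both derivatives of the plain model decay toward zero as x runs off to -inf, while with the x**2 / 16 term f'' tends to 1/8 far from the origin and f grows without bound, so a minimum must exist.

```python
import math

# Hypothetical weights for f(x) = sum_j w2[j] * tanh(w1[j] * x + b[j])
w1, b, w2 = [0.8, -0.5, 0.3], [0.1, 0.4, -0.2], [-0.6, 0.7, -0.4]

def sech2(z):
    return 1.0 / math.cosh(z) ** 2

def d1(x):   # f'(x) of the plain model
    return sum(c * a * sech2(a * x + bb) for a, bb, c in zip(w1, b, w2))

def d2(x):   # f''(x), using d/dz sech^2(z) = -2 tanh(z) sech^2(z)
    return sum(c * a * a * (-2.0 * math.tanh(a * x + bb)) * sech2(a * x + bb)
               for a, bb, c in zip(w1, b, w2))

for x in (0.0, -10.0, -40.0):
    # columns: x, f', f'' (plain), then f', f'' with the + x**2 / 16 term
    print(x, d1(x), d2(x), d1(x) + x / 8, d2(x) + 1 / 8)
```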

For completeness, here is the modified script:

```python
import torch
print (torch.__version__)

_ = torch.manual_seed (456)   # different seed causes different random initialization of model

class FStab (torch.nn.Module):
    def __init__ (self):
        super().__init__()
        self.fseq = torch.nn.Sequential (   # network f, with fixed randomly initialized parameters (theta)
            torch.nn.Linear (1, 3),
            torch.nn.Tanh(),
            torch.nn.Linear (3, 1, bias = False)
        )
    def forward (self, x):                  # add stabilization term in forward
        return  self.fseq (x) + x**2 / 16

f = FStab()

x = torch.tensor ([0.0], requires_grad = True)   # vary x to minimize f(x)

# use pytorch optimizer to find argmin of f(x)
opt = torch.optim.Adam ([x], lr = 0.5)
for  i in range (350):
    opt.zero_grad()
    l = f (x)
    l.backward()
    opt.step()

# x is now x_star, the argmin of f(x) (for fixed theta)
x_star = x.detach().clone()   # save a copy for later
print ('x_star:', x_star.item())
print ('l:', l.item())

# compute second partials, d^2_f / d_x^2 and d^2_f / d_x d_theta, evaluated at x = x_star

l = f (x)
d_x = torch.autograd.grad (l, x, retain_graph=True, create_graph=True)[0]   # first derivative of f with respect to x
print ('d_x:', d_x)                                                         # d_x is zero because x_star is a stationary point
d2_x = torch.autograd.grad (d_x, x, retain_graph = True)[0]                 # second derivative of f with respect to x
print ('d2_x:', d2_x)                                                       # d2_x is positive because that stationary point is a minimum
d_x_d_th = torch.autograd.grad (d_x, list (f.parameters()))                 # mixed partials d^2_f / d_x d_theta, one tensor per parameter

print ('d_x_d_th:')
for  t in d_x_d_th:
    print (t)

d_x_star_d_weight_1 = -d_x_d_th[0] / d2_x                                   # gradient with respect to weight of first Linear
d_x_star_d_bias_1 = -d_x_d_th[1] / d2_x                                     # gradient with respect to bias of first Linear
d_x_star_d_weight_2 = -d_x_d_th[2] / d2_x                                   # gradient with respect to weight of second Linear

print ('d_x_star_d_theta:')
print (d_x_star_d_weight_1)
print (d_x_star_d_bias_1)
print (d_x_star_d_weight_2)

# numerically check one value of gradient
delta = 1.e-4
with torch.no_grad():                # in-place change to a parameter must not be tracked by autograd
    f.fseq[0].weight[1, 0] += delta

# recompute x_star for perturbed theta (starting from the old x_star)
for  i in range (125):
    opt.zero_grad()
    l = f (x)
    l.backward()
    opt.step()

grad_num = (x - x_star) / delta
print ('grad_num:', grad_num)        # compare with d_x_star_d_weight_1[1, 0]
```



And its output:

```
1.13.1
x_star: -1.146886944770813
l: -0.24911892414093018
d2_x: tensor([0.2937])
d_x_d_th:
tensor([[ 0.6172],
[ 0.1321],
[-0.2920]])
tensor([-0.0671, -0.2309, -0.0733])
tensor([[ 0.3524, -0.2288,  0.4708]])
d_x_star_d_theta:
tensor([[-2.1014],
[-0.4498],
[ 0.9942]])
tensor([0.2285, 0.7860, 0.2496])
tensor([[-1.1997,  0.7790, -1.6029]])
```

Note that when evaluated at the (successfully-converged) x_star, d_x is
zero and d2_x is comfortably positive.
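Since the whole approach depends on x_star really being a minimum, a small pre-check before applying the formula can catch the failure mode seen with seed 456. This is a sketch with a hypothetical helper name and toy objectives: it requires a near-zero gradient and a positive smallest Hessian eigenvalue (in the 1-D case the Hessian is just d2_x). A positive but tiny eigenvalue, like the 0.0003 seen with seed 456, still signals a nearly flat bowl, so the tolerance and what counts as "comfortably positive" remain judgment calls.

```python
import torch

def check_minimum(f, x_star, grad_tol=1e-6):
    """Hypothetical helper: check that x_star looks like a strict local minimum of f
    (near-zero gradient, positive-definite Hessian) before trusting the argmin-gradient formula.

    Returns (gradient norm, smallest Hessian eigenvalue, passed)."""
    x = x_star.detach().clone().requires_grad_(True)
    func = lambda v: f(v).sum()                       # scalar objective
    grad = torch.autograd.grad(func(x), x)[0]
    H = torch.autograd.functional.hessian(func, x).reshape(x.numel(), x.numel())
    min_eig = torch.linalg.eigvalsh(H).min()          # assumes H is (close to) symmetric
    passed = bool(grad.norm() < grad_tol) and bool(min_eig > 0)
    return grad.norm().item(), min_eig.item(), passed

bowl = lambda v: ((v - 2.0) ** 2).sum()               # minimum at (2, 2)
saddle = lambda v: v[0] ** 2 - v[1] ** 2              # stationary at (0, 0), but not a minimum

print(check_minimum(bowl, torch.tensor([2.0, 2.0])))    # passes
print(check_minimum(bowl, torch.tensor([2.5, 2.0])))    # fails: gradient is not zero
print(check_minimum(saddle, torch.tensor([0.0, 0.0])))  # fails: negative eigenvalue
```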