Hello, Frank and Xuehai:
I implemented a toy model to estimate dx_dtheta, where x is a high-dimensional vector.
Specifically, I use dx_dtheta = - H_f[x]^-1 * d2f_dxdtheta, where H_f[x]^-1 is the inverse Hessian of f with respect to x and d2f_dxdtheta is the mixed partial derivative d^2f / (dx dtheta), both evaluated at the minimizer x_star.
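For reference, the formula comes from differentiating the first-order optimality condition at the minimizer with respect to theta. This is a sketch of the standard implicit-function-theorem argument, assuming f is twice continuously differentiable and the Hessian is invertible at x*:

$$
\nabla_x f\big(x^*(\theta), \theta\big) = 0
\;\Longrightarrow\;
\nabla_x^2 f \,\frac{\partial x^*}{\partial \theta} + \frac{\partial^2 f}{\partial x\,\partial \theta} = 0
\;\Longrightarrow\;
\frac{\partial x^*}{\partial \theta} = -\big(\nabla_x^2 f\big)^{-1}\,\frac{\partial^2 f}{\partial x\,\partial \theta},
$$

with every derivative evaluated at $(x^*(\theta), \theta)$. The last step needs $\nabla_x^2 f$ to be invertible at $x^*$; if it is singular or nearly singular, the inverse is undefined and a numerically computed one can be dominated by rounding error.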
However, the dx_dtheta I estimated looks very strange: its entries are on the order of 10^15 or larger.
By contrast, the estimate obtained via finite differences is of order 1 to 10.
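On a problem whose minimizer is known in closed form, the IFT estimate and a finite-difference estimate should agree, which makes it a useful sanity check for the autograd plumbing independent of the network. A minimal sketch, assuming a hypothetical quadratic objective f(x, theta) = 0.5 x^T A x - theta^T x with a hand-picked symmetric positive definite A (so x*(theta) = A^{-1} theta and the exact Jacobian is A^{-1}); the names `A`, `argmin`, and `delta` are illustrative only:

```python
import torch

torch.set_default_dtype(torch.float64)

# Hypothetical toy objective with a closed-form argmin:
#   f(x, theta) = 0.5 * x^T A x - theta^T x, A symmetric positive definite,
# so x*(theta) = A^{-1} theta and the exact Jacobian dx*/dtheta is A^{-1}.
A = torch.tensor([[3.0, 1.0, 0.0],
                  [1.0, 2.0, 0.5],
                  [0.0, 0.5, 1.5]])

def f(x, theta):
    return 0.5 * x @ A @ x - theta @ x

def argmin(theta):
    # Exact minimizer: solve A x = theta
    return torch.linalg.solve(A, theta)

theta = torch.tensor([1.0, -2.0, 0.5], requires_grad=True)
x_star = argmin(theta.detach()).requires_grad_(True)

# First derivative df/dx, kept in the graph so it can be differentiated again
g = torch.autograd.grad(f(x_star, theta), x_star, create_graph=True)[0]
n = x_star.numel()
H = torch.stack([torch.autograd.grad(g[i], x_star, retain_graph=True)[0] for i in range(n)])  # d2f/dx2
M = torch.stack([torch.autograd.grad(g[i], theta, retain_graph=True)[0] for i in range(n)])   # d2f/dx dtheta

# IFT: dx*/dtheta = -H^{-1} M, computed with solve instead of an explicit inverse
dx_dtheta_ift = -torch.linalg.solve(H, M)

# Central finite differences, re-solving exactly for the argmin after each perturbation
delta = 1e-6
dx_dtheta_fd = torch.empty(n, n)
with torch.no_grad():
    for j in range(n):
        e = torch.zeros(n)
        e[j] = delta
        dx_dtheta_fd[:, j] = (argmin(theta + e) - argmin(theta - e)) / (2 * delta)

print(torch.allclose(dx_dtheta_ift, dx_dtheta_fd, atol=1e-6))
print(torch.allclose(dx_dtheta_ift, torch.linalg.inv(A)))
```

Because x* is linear in theta here and H = A is well conditioned, both comparisons should hold up to rounding error. Using `torch.linalg.solve` rather than `torch.linalg.inv` is a common choice for the IFT step, although it does not help when the Hessian itself is singular.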
My code and outputs are attached below. Any idea which part is wrong?
Code:
import torch
_ = torch.manual_seed(123)
torch.set_default_dtype(torch.float64)
print(torch.__version__)
f = torch.nn.Sequential ( # network f, with fixed randomly initialized parameters (theta)
    torch.nn.Linear (3, 2),
    torch.nn.Tanh(),
    torch.nn.Linear (2, 1, bias = False)
)
x = torch.tensor ([0.0, 0.0, 0.0], requires_grad = True) # vary x to minimize f(x)
# use pytorch optimizer to find argmin of f(x)
opt = torch.optim.Adam ([x], lr = 0.5)
for i in range (350):
    l = f (x)
    opt.zero_grad()
    l.backward()
    opt.step()
# x is now x_star, the argmin of f(x) (for fixed theta)
x_star = x.detach().clone() # save a copy for later
print ('x_star:', x_star)
print ('l:', l.item())
# compute second partials, d^2_f / d_x^2 and d^2_f / d_x d_theta, evaluated at x = x_star
x.grad = None
l = f (x)
d_x = torch.autograd.grad (l, x, retain_graph=True, create_graph=True)[0] # first derivative of f with respect to x
print ('d_x:', d_x)
d2_x = torch.stack([torch.autograd.grad (d_x[ind], x, retain_graph=True, create_graph=True)[0] for ind in range(3)])
print ('d2_x:', d2_x)
# only look at the first fc layer's weight
theta = list(f.parameters())[0]
d_f_d_x_d_th = torch.stack([torch.autograd.grad(d_x[ind], theta, retain_graph=True, create_graph=True)[0] for ind in range(3)]).reshape(*x.shape,*theta.shape) # mixed second partials
print ('d_f_d_x_d_th:', d_f_d_x_d_th)
# use the IFT here
d_x_d_th = - torch.matmul(torch.linalg.inv(d2_x), d_f_d_x_d_th.reshape(3,-1))
print('d_x_d_th:', d_x_d_th)
# numerically check one value of gradient
# setting delta to 1e-6, 1e-8, or 1e-10 gives the same numerical values, so it is quite stable
x = torch.tensor ([0.0, 0.0, 0.0], requires_grad = True) # restart from the same initial point and minimize f(x) again
opt = torch.optim.Adam ([x], lr = 0.5)
delta = 1.e-10
with torch.no_grad():
f[0].weight[0, 0] += delta
# recompute x_star for perturbed theta
for i in range (350):
    l = f (x)
    opt.zero_grad()
    l.backward()
    opt.step()
grad_num = (x - x_star) / delta
print('analycal: ',d_x_d_th[:,0])
print('numerial: ',grad_num)
Outputs:
1.10.0
x_star: tensor([-10.2784, 9.1450, -8.9501])
l: -1.2608992210123384
d_x: tensor([ 1.3880e-04, -1.9742e-04, 2.9152e-05], grad_fn=<SqueezeBackward1>)
d2_x: tensor([[ 2.0108e-04,  7.6194e-05, -2.1971e-05],
        [ 7.6194e-05,  2.3750e-04, -4.5307e-05],
        [-2.1971e-05, -4.5307e-05,  8.9560e-06]], grad_fn=<StackBackward0>)
d_f_d_x_d_th: tensor([[[ 0.0015, -0.0010,  0.0010],
         [ 0.0036, -0.0036,  0.0035]],

        [[ 0.0043, -0.0035,  0.0038],
         [ 0.0003, -0.0007,  0.0002]],

        [[-0.0008,  0.0007, -0.0003],
         [-0.0002,  0.0002, -0.0006]]], grad_fn=<ReshapeAliasBackward0>)
d_x_d_th: tensor([[ 2.0818e+15,  8.7662e+15,  4.9454e+16, -2.3087e+15, -9.7215e+15,
         -5.4843e+16],
        [ 8.7662e+15,  3.6913e+16,  2.0824e+17, -9.7215e+15, -4.0935e+16,
         -2.3093e+17],
        [ 4.9454e+16,  2.0824e+17,  1.1748e+18, -5.4843e+16, -2.3093e+17,
         -1.3028e+18]], grad_fn=<NegBackward0>)
analycal: tensor([2.0818e+15, 8.7662e+15, 4.9454e+16],
grad_fn=<SelectBackward0>)
numerial: tensor([ 0.9589, -12.7340, 15.6670], grad_fn=<DivBackward0>)
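One property of this particular architecture may bear on the IFT step. x (3-dimensional) enters f only through Linear(3, 2), so f does not change when x moves along the null space of the first layer's weight W1 (shape 2 x 3). The Hessian with respect to x is W1^T diag(w2 * tanh''(z)) W1, where w2 is the second layer's weight and z = W1 x + b, so it has rank at most 2 and a zero eigenvalue at every x. A minimal sketch to check this on a freshly initialized network with the same layer sizes; the seed and the evaluation point are arbitrary assumptions, not the values from the run above:

```python
import torch

torch.manual_seed(0)  # arbitrary seed, not the one used in the post
torch.set_default_dtype(torch.float64)

# Same layer sizes as in the post: x in R^3 enters only through Linear(3, 2)
f = torch.nn.Sequential(
    torch.nn.Linear(3, 2),
    torch.nn.Tanh(),
    torch.nn.Linear(2, 1, bias=False),
)
W1 = f[0].weight.detach()  # shape (2, 3)

x = torch.randn(3, requires_grad=True)  # arbitrary evaluation point
g = torch.autograd.grad(f(x).sum(), x, create_graph=True)[0]  # df/dx
H = torch.stack([torch.autograd.grad(g[i], x, retain_graph=True)[0] for i in range(3)])  # d2f/dx2

# n spans the null space of W1, so f(x + t * n) == f(x) for every t
n = torch.cross(W1[0], W1[1], dim=0)
n = n / n.norm()

eigvals = torch.linalg.eigvalsh(H)
print("eigenvalues of H:", eigvals)           # one of them should be at rounding level
print("|H @ n|:", (H @ n).norm().item())      # near zero: n is a null direction of H
```

Inverting a matrix whose smallest eigenvalue is at rounding level amplifies rounding error by roughly its condition number; `torch.linalg.cond(d2_x)` is one quick way to inspect this on the original run.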