Second derivatives of a function of two variables

My model takes two inputs x and y and returns T(x, y).
How can I find the second derivatives d²T/dx² and d²T/dy²?

So far I have only managed to compute the first derivatives, like this:

# g - tensor of input values, one (x, y) pair per row, with requires_grad=True
T = NN(g)
# first derivatives dT/dx and dT/dy for every sample
T_x_y = autograd.grad(T, g, torch.ones([g.shape[0], 1]).to(device),
                      retain_graph=True, create_graph=True)[0]

Hi Mikhail!

Continue on with the code you have. Because you pass create_graph=True,
your call to autograd.grad() not only computes the first derivative, it
also builds the computation graph for that first derivative. So you can
simply call autograd.grad() a second time to compute the second derivative.

Here is an illustration in the spirit of your code:

>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> g = torch.tensor ([5.0, 7.0], requires_grad = True)
>>>
>>> T = g[0]**2 * g[1]**3
>>> T_x_y = torch.autograd.grad (T, g, retain_graph=True, create_graph=True)[0]   # creates graph of first derivative
>>>
>>> T_d_xx = torch.autograd.grad (T_x_y[0], g, retain_graph = True)[0][0]         # compute second derivative (and save graph for T_d_yy computation)
>>> T_d_yy = torch.autograd.grad (T_x_y[1], g)[0][1]                              # compute second derivative
>>>
>>> T_d_xx_B = 2 * g[1]**3        # check second derivative "by hand"
>>> T_d_yy_B = g[0]**2 * 6*g[1]   # check second derivative "by hand"
>>>
>>> T_d_xx
tensor(686.)
>>> T_d_xx_B
tensor(686., grad_fn=<MulBackward0>)
>>> T_d_yy
tensor(1050.)
>>> T_d_yy_B
tensor(1050., grad_fn=<MulBackward0>)
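
For completeness, here is how the same two grad() calls might look with a
batched network input like your g. This is only a minimal sketch: the small
network, batch size, and device below are placeholders, not your actual
setup; only the autograd.grad() calls matter.

import torch

device = "cpu"                                           # placeholder device
NN = torch.nn.Sequential(                                # stand-in for your model
    torch.nn.Linear(2, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1)
).to(device)

g = torch.rand(8, 2, device=device, requires_grad=True)  # one (x, y) pair per row
T = NN(g)                                                # shape [8, 1]

# first derivatives (dT/dx, dT/dy) for every sample, shape [8, 2]
T_x_y = torch.autograd.grad(
    T, g, torch.ones([g.shape[0], 1]).to(device),
    retain_graph=True, create_graph=True
)[0]

# second derivatives, one value per sample
T_xx = torch.autograd.grad(
    T_x_y[:, 0], g, torch.ones_like(T_x_y[:, 0]), retain_graph=True
)[0][:, 0]                                               # d2T/dx2, shape [8]
T_yy = torch.autograd.grad(
    T_x_y[:, 1], g, torch.ones_like(T_x_y[:, 1])
)[0][:, 1]                                               # d2T/dy2, shape [8]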

Best.

K. Frank


It has been a while since the last reply in this thread, but I have a similar question (and associated error) and thought it would be more appropriate to post it here. In PyTorch version 2.1.0+cu118 I have set up two different ways, Code I and Code II, to calculate second derivatives of a multivariable function. Code I works as expected, but Code II raises the error shown below. How can I make Code II work?

Code I:

import torch

def f(x, t):
    return (3 * x ** 3) * t + (t ** 3) * x

x = torch.tensor([2.06, 5.7], requires_grad=True)
t = torch.tensor([1.5, 0.33], requires_grad=True)

y = f(x, t)

dy_dt = torch.autograd.grad(y.sum(), t, create_graph=True)[0]
dy_dx = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
d2y_dx2 = torch.autograd.grad(dy_dx.sum(), x)[0]

loss = dy_dt + dy_dx - 0.01 * d2y_dx2

print(loss)

Output I:

tensor([100.2378, 653.6338], grad_fn=<SubBackward0>)

Code II:

x = torch.tensor([2.06, 5.7], requires_grad=True)
t = torch.tensor([1.5, 0.33], requires_grad=True)

model = torch.nn.Linear(2, 1)

var_input = torch.stack([x, t], dim=1)

u = model(var_input)

du_dt = torch.autograd.grad(u.sum(), t, create_graph=True)[0]
du_dx = torch.autograd.grad(u.sum(), x, create_graph=True)[0]
d2u_dx2 = torch.autograd.grad(du_dx.sum(), x)[0]

loss = du_dt + du_dx - 0.01 * d2u_dx2

print(loss)

Output II:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-40-bee5e3fcde20> in <cell line: 12>()
     10 du_dt = torch.autograd.grad(u.sum(), t, create_graph=True)[0]
     11 du_dx = torch.autograd.grad(u.sum(), x, create_graph=True)[0]
---> 12 d2u_dx2 = torch.autograd.grad(du_dx.sum(), x)[0]
     13 
     14 loss = du_dt + du_dx - 0.01 * d2u_dx2

/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched, materialize_grads)
    392         )
    393     else:
--> 394         result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    395             t_outputs,
    396             grad_outputs_,

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Hi Burak!

In your second case, u is linear in x (and in t). Therefore du_dx is a
constant with respect to x, which is to say independent of x. Autograd
recognizes this and, for efficiency, removes x from du_dx’s computation
graph (at least in effect).

To see what is going on, try var_input = torch.stack([x**2, t], dim=1).
u will no longer be linear in x, so the second derivative of u with respect
to x will no longer vanish. Alternatively, in the linear case you can pass
allow_unused = True; autograd will then return None for x (i.e., a vanishing
second derivative) instead of raising the error. Both options are sketched below.
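
Here is a minimal sketch of both options, reusing the tensors and model
from your Code II (the values are just your example numbers):

import torch

x = torch.tensor([2.06, 5.7], requires_grad=True)
t = torch.tensor([1.5, 0.33], requires_grad=True)

model = torch.nn.Linear(2, 1)

# option 1: make u nonlinear in x so that d2u/dx2 no longer vanishes
var_input = torch.stack([x**2, t], dim=1)
u = model(var_input)
du_dx = torch.autograd.grad(u.sum(), x, create_graph=True)[0]
d2u_dx2 = torch.autograd.grad(du_dx.sum(), x)[0]
print(d2u_dx2)        # equals 2 * model.weight[0, 0] for each sample

# option 2: keep the linear model and tell autograd that x may be unused
var_input = torch.stack([x, t], dim=1)
u = model(var_input)
du_dx = torch.autograd.grad(u.sum(), x, create_graph=True)[0]
d2u_dx2 = torch.autograd.grad(du_dx.sum(), x, allow_unused=True)[0]
print(d2u_dx2)        # None -- the second derivative vanishes identically
# (your version, 2.1.0, also accepts materialize_grads=True to get an
#  explicit zero tensor instead of None)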

As an aside, it probably would have been better to post this question in a
new thread, rather than resurrecting a zombie.

Best.

K. Frank


Thank you very much! This reply made me realize that the simple linear model I used in my question for demonstration is the cause of my issue; the model I actually have in hand is nonlinear. In short, I had created an imaginary problem for myself.