Gradient computation with PyTorch autograd with 1st- and 2nd-order derivatives does not work

I am having a weird issue with PyTorch’s autograd when implementing a custom loss for a second-order differential equation. In the code below, the predictions of the neural network are checked to see whether they satisfy a second-order differential equation. This works fine. However, when I want to calculate the gradient of the loss with respect to the predictions, I get an error indicating that there is no connection between loss and u in the computational graph.

RuntimeError: One of the differentiated Tensors appears to not have
been used in the graph. Set allow_unused=True if this is the desired
behavior.

This doesn’t make sense to me, because the loss is calculated directly from the derivatives that originate from u. Differentiating the loss with respect to u_xx and u_t works; differentiating with respect to u_x does NOT. We verified that .requires_grad is set to True for all variables (X, u, u_d, u_x, u_t, u_xx).

Why does this happen, and how can I fix it?

Main code:

# Ensure X requires gradients
X.requires_grad_(True)

# Get model predictions
u = self.pinn(X)

# Compute first-order gradients (∂u/∂x and ∂u/∂t)
u_d = torch.autograd.grad(
    u,
    X,
    grad_outputs=torch.ones_like(u),
    retain_graph=True,
    create_graph=True,  # Allow higher-order differentiation
)[0]

# Extract derivatives
u_x, u_t = u_d[:, 0], u_d[:, 1]  # ∂u/∂x and ∂u/∂t

# Compute second-order derivative ∂²u/∂x²
u_xx = torch.autograd.grad(
    u_x,
    X,
    grad_outputs=torch.ones_like(u_x),
    retain_graph=True,
    create_graph=True,
)[0][:, 0]

# Diffusion equation (∂u/∂t = κ * ∂²u/∂x²)
loss = nn.functional.mse_loss(u_t, self.kappa * u_xx)

## THIS FAILS
# Compute ∂loss/∂u
loss_u = torch.autograd.grad(
    loss,
    u,
    grad_outputs=torch.ones_like(loss),
    retain_graph=True,
    create_graph=True,
)[0]

# Return error on diffusion equation
return loss

Model:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Sequential                               [1, 1]                    --
├─Linear: 1-1                            [1, 50]                   150
├─Tanh: 1-2                              [1, 50]                   --
├─Linear: 1-3                            [1, 50]                   2,550
├─Tanh: 1-4                              [1, 50]                   --
├─Linear: 1-5                            [1, 50]                   2,550
├─Tanh: 1-6                              [1, 50]                   --
├─Linear: 1-7                            [1, 50]                   2,550
├─Tanh: 1-8                              [1, 50]                   --
├─Linear: 1-9                            [1, 1]                    51
==========================================================================================
Total params: 7,851
Trainable params: 7,851
Non-trainable params: 0
Total mult-adds (M): 0.01
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.03
Estimated Total Size (MB): 0.03
==========================================================================================

What we have already tried:

Reverted to an older PyTorch version (tested on 2.5.0 and 1.13.1). Same issue.

Putting .requires_grad_(True) after every variable assignment. This did not help.

We also tried replacing the tensor slicing with multiplication by zero/one vectors, without success. We thought the slicing might disturb the computational graph and break the connection to u.

# Extract derivatives
u_x, u_t = u_d[:, 0], u_d[:, 1]  # ∂u/∂x and ∂u/∂t

# Extract derivatives alternative
u_x = torch.sum(
    torch.reshape(torch.tensor([1, 0], device=u_d.device), [1, -1]) * u_d,
    dim=1,
    keepdim=True,
)
u_t = ...  # analogous to u_x, but using [0, 1]

Thanks for your help!

Hi woutr!

You could also say that the prior derivatives originate from X, rather than from u.

Consider y = x**2. In “ordinary calculus” you would typically say that
dy / dx = 2 * x, so that dy / dx depends on x. You would not typically say
that dy / dx = 2 * sqrt (y), although this is in a sense true. PyTorch’s autograd
(appropriately) takes the “ordinary calculus” view that dy / dx depends on x, but
not directly on y.

(As a blunter example, consider y = x**2 and z = exp (x) (so that x = log (z)).
Although dy / dx = 2 * x could be expressed as dy / dx = 2 * log (z), you
would not expect autograd to know how to differentiate dy / dx with respect to z.
You would have to do that yourself “by hand.”)
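
Here is a minimal sketch of that point (hypothetical tensors, not taken from your code): autograd returns None when asked for the derivative of dy / dx with respect to z, while computing it "by hand" with the chain rule gives 2 / z:

import torch

x = torch.tensor ([0.5], requires_grad = True)
y = x**2
z = torch.exp (x)   # so x = log (z)

dy_dx = torch.autograd.grad (y, x, create_graph = True)[0]   # 2 * x

# autograd sees dy / dx as a function of x, not of z
print (torch.autograd.grad (dy_dx, z, allow_unused = True))   # (None,)

# "by hand": d (dy / dx) / dz = d (2 * x) / dx * dx / dz = 2 / z
print (2.0 / z)   # tensor([1.2131], ...)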

Consider this illustrative script based loosely on your code:

import torch
print (torch.__version__)

a = torch.tensor ([1.234], requires_grad = True)
b = a**2

db_da = torch.autograd.grad (b, a, create_graph = True)

d_db_da_da = torch.autograd.grad (db_da, a, create_graph = True)                        # works
print ('d_db_da_da:', d_db_da_da)

# d_db_da_db = torch.autograd.grad (db_da, b, create_graph = True)                      # would fail

d_db_da_db = torch.autograd.grad (db_da, b, create_graph = True, allow_unused = True)   # runs, but None
print ('d_db_da_db:', d_db_da_db)

Here is its output:

2.6.0+cu126
d_db_da_da: (tensor([2.], grad_fn=<MulBackward0>),)
d_db_da_db: (None,)

Without knowing what is inside of self.pinn(), it’s hard to know what you are trying
to do or why.

One thought: Going back to single-variable calculus, you could say that:

d (dy / dx) / dy = (d (dy / dx) / dx)  /  (dy / dx)

where both terms on the right-hand side can be computed with autograd.
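
As a quick sketch of that identity (again with a hypothetical y = x**2 rather than your pinn output), both right-hand-side terms come straight from autograd:

import torch

x = torch.tensor ([1.234], requires_grad = True)
y = x**2

dy_dx = torch.autograd.grad (y, x, create_graph = True)[0]         # 2 * x
d2y_dx2 = torch.autograd.grad (dy_dx, x, create_graph = True)[0]   # 2

# d (dy / dx) / dy = (d (dy / dx) / dx) / (dy / dx) = 2 / (2 * x) = 1 / x
d_dydx_dy = d2y_dx2 / dy_dx
print ('d_dydx_dy:', d_dydx_dy)   # tensor([0.8104], ...), i.e., 1 / 1.234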

Best.

K. Frank
