Derivative of NN in the loss function


I want to approximate the value function V(x) of a continuous-time system with a neural network.
For training, the loss function has to contain the derivative of the model (V(x)) with respect to its input. How can I implement this in PyTorch?
I have added a standard LQR example below, but the loss is not decreasing.
I think grad_V_x no longer has a symbolic dependence on the network weights, and thus steepest descent does not work.
Any help is appreciated!

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

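# setting seed (for reproducible learning)
seed = 0
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed) # assumed: the body of this branch was missing in the post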

# model dynamics
A = 1
B = 1

# cost function 
R = 1
Q = 1

# Architecture of the neural network 
model = nn.Sequential(
    nn.Linear(1, 10),
    nn.Linear(10, 10),
    nn.Linear(10, 1)
)

K=1.1 # stable initial policy 

# loss function and optimizer
loss_fn = nn.MSELoss()  # mean square error
optimizer = optim.Adam(model.parameters(), lr=0.015) 

for l in range(1, 10): # Policy iteration
    N = 11 # number of samples
    x_0 = np.linspace(-1,1,N)
    x_dot = np.zeros((N,1))   
    r = np.zeros((N,1))

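    if l==1: # start with stable initial policy 
        for i in range(N):
            u = -K*x_0[i] # stable initial policy u = -K*x
            x_dot[i] = A*x_0[i] + B*u # closed-loop dynamics (assumed; x_dot was never filled in)
            r[i]=Q*x_0[i]**2 + R*u**2
    else: # use NN
        for i in range(N):
            x_i = x_0_tensor[i].clone().requires_grad_(True) # input that tracks gradients
            grad_V_x = torch.autograd.grad(model(x_i).sum(), x_i)[0] # dV/dx, not V itself
            grad_V_x = grad_V_x.numpy() # convert to numpy
            u = -0.5*(1/R)*B*grad_V_x # calculate optimal input
            x_dot[i] = A*x_0[i] + B*u # closed-loop dynamics (assumed; x_dot was never filled in)
            r[i]=Q*x_0[i]**2 + R*(u)**2 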

    # train NN 
    epoch = 1
    tol = 0.0001
    loss = np.inf

    while loss>tol and epoch<3000: # terminal criterion 
        x_0_tensor = torch.from_numpy(np.expand_dims(x_0,axis=1)).float()
        x_dot_tensor = torch.from_numpy(x_dot).float()

        r_tensor = torch.from_numpy(r).float()

        grad_V_x = torch.autograd.functional.jacobian(model,x_0_tensor)
        loss = loss_fn(-grad_V_x*x_dot_tensor, r_tensor) # minimize Bellman error 
        optimizer.zero_grad() # reset gradients
        loss.backward() # calculate gradient 
        optimizer.step() # update weights
        print(f"epoch {epoch} loss {loss}")
        epoch = epoch+1 # count epochs
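
The suspicion about the lost dependence on the weights can be checked directly. The following is only a minimal sketch, assuming the same layer sizes as the model above and an illustrative batch of 11 inputs; the names jac_detached and jac_tracked are made up for this check. It compares the result of torch.autograd.functional.jacobian with and without create_graph=True.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same layer sizes as the model in the example above
model = nn.Sequential(nn.Linear(1, 10), nn.Linear(10, 10), nn.Linear(10, 1))
x = torch.linspace(-1, 1, 11).unsqueeze(1)  # illustrative inputs, shape (11, 1)

# Default call: the returned Jacobian is detached from the autograd graph
jac_detached = torch.autograd.functional.jacobian(model, x)

# create_graph=True keeps the dependence on the network weights
jac_tracked = torch.autograd.functional.jacobian(model, x, create_graph=True)

print(jac_detached.shape)
print(jac_detached.requires_grad)
print(jac_tracked.requires_grad)
```

Note that for a batched input the Jacobian differentiates every output with respect to every input sample, so only the entries that pair sample i with input i are the per-sample derivatives dV/dx.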


Hi @xn2402,

I’d recommend looking at the torch.func namespace, as it allows derivatives to be composed inside the loss function, although the syntax differs slightly from standard autograd. The documentation can be found on the torch.func page of the PyTorch 2.0 documentation.
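
As a rough sketch of what this could look like (not a definitive implementation), the snippet below computes dV/dx per sample with torch.func.grad, batches it with torch.func.vmap, and keeps the result differentiable with respect to the weights via torch.func.functional_call. Several things here are assumptions of mine rather than part of your code: the helper names value and dV_dx, the Tanh activations (a stack of purely linear layers has a constant derivative with respect to x, so it cannot represent a quadratic value function), the fixed policy u = -K*x, and the 200 training steps.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Scalar LQR setup taken from the question: x_dot = A*x + B*u
A, B, Q, R, K = 1.0, 1.0, 1.0, 1.0, 1.1

# Assumption: Tanh activations added so that dV/dx can vary with x
model = nn.Sequential(
    nn.Linear(1, 10), nn.Tanh(),
    nn.Linear(10, 10), nn.Tanh(),
    nn.Linear(10, 1),
)
params = dict(model.named_parameters())  # references the live parameters


def value(params, x):
    """V(x) for a single sample x of shape (1,), returned as a scalar."""
    return functional_call(model, params, (x,)).squeeze()


# Per-sample derivative dV/dx, mapped over the batch dimension of x only
dV_dx = vmap(grad(value, argnums=1), in_dims=(None, 0))

x = torch.linspace(-1, 1, 11).unsqueeze(1)  # samples, shape (11, 1)
u = -K * x                                  # stable initial policy
x_dot = A * x + B * u                       # closed-loop dynamics
r = Q * x**2 + R * u**2                     # stage cost

optimizer = torch.optim.Adam(model.parameters(), lr=0.015)
loss_fn = nn.MSELoss()

losses = []
for epoch in range(200):
    grad_V_x = dV_dx(params, x)            # shape (11, 1), depends on the weights
    loss = loss_fn(-grad_V_x * x_dot, r)   # Bellman error for the fixed policy
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

Because grad is applied to a single-sample function and then vmapped, grad_V_x has one derivative per sample instead of a full batch-by-batch Jacobian, which makes the product with x_dot line up elementwise.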