Hi,
I want to approximate the value function V(x) of a continuous-time system with a neural network.
For the learning, the loss function has to contain the derivative of the model output V(x) with respect to the input. How can I implement this in PyTorch?
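Concretely, for a fixed policy u(x) and dynamics ẋ = Ax + Bu, the training below is meant to drive the continuous-time Bellman (HJB) residual to zero:

$$0 = r(x, u) + \frac{\partial V}{\partial x}(x)\,\dot{x}, \qquad r(x, u) = Q x^2 + R u^2$$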
I added a standard LQR example below, but the loss does not decrease.
I suspect that grad_V_x no longer has a symbolic dependence on the network weights, so gradient descent cannot update them.
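For reference, here is a minimal self-contained sketch of what I think the mechanism should look like (my assumption, using torch.autograd.grad with create_graph=True so that dV/dx stays connected to the weights):

import torch
import torch.nn as nn

# toy network for V(x); same structure as in my script, shrunk for the sketch
model = nn.Sequential(nn.Linear(1, 10), nn.ReLU(), nn.Linear(10, 1))

x = torch.linspace(-1.0, 1.0, 11).unsqueeze(1)  # shape (N, 1)
x.requires_grad_()  # needed so dV/dx can be computed
V = model(x)  # V(x), shape (N, 1)

# create_graph=True keeps dV/dx inside the autograd graph, so a loss
# built from it can still backpropagate into the network weights.
# Summing V is safe here because each output row depends only on its
# own input row, so grad_V_x[i] = dV/dx evaluated at x[i].
(grad_V_x,) = torch.autograd.grad(V.sum(), x, create_graph=True)
print(grad_V_x.requires_grad)  # True -> still connected to the weights

Is this the right approach, and how do I integrate it into the training loop below?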
Any help is appreciated!
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
# setting seed (for reproducible learning)
seed = 0
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
# model dynamics
A = 1
B = 1
# cost function
R = 1
Q = 1
# Architecture of the neural network
model = nn.Sequential(
    nn.Linear(1, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 1)
)
K = 1.1  # stable initial policy
# loss function and optimizer
loss_fn = nn.MSELoss() # mean square error
optimizer = optim.Adam(model.parameters(), lr=0.015)
l = 1
while l < 10:  # policy iteration
    N = 11  # number of samples
    x_0 = np.linspace(-1, 1, N)
    x_dot = np.zeros((N, 1))
    r = np.zeros((N, 1))
    if l == 1:  # start with the stable initial policy
        for i in range(N):
            x_dot[i] = A*x_0[i] - B*K*x_0[i]
            r[i] = Q*x_0[i]**2 + R*(K*x_0[i])**2
    else:  # use the NN for policy improvement
        for i in range(N):
            grad_V_x = model(x_0_tensor[i])  # network prediction
            grad_V_x = grad_V_x.detach().numpy()  # convert to NumPy
            u = -0.5*(1/R)*B*grad_V_x  # calculate the optimal input
            x_dot[i] = A*x_0[i] + B*u
            r[i] = Q*x_0[i]**2 + R*u**2
    # train the NN (policy evaluation)
    epoch = 1
    tol = 0.0001
    loss = np.inf
    while loss > tol and epoch < 3000:  # termination criterion
        x_0_tensor = torch.from_numpy(np.expand_dims(x_0, axis=1)).float()
        x_0_tensor.requires_grad_()
        x_dot_tensor = torch.from_numpy(x_dot).float()
        r_tensor = torch.from_numpy(r).float()
        grad_V_x = torch.autograd.functional.jacobian(model, x_0_tensor)
        grad_V_x.requires_grad_()
        loss = loss_fn(r_tensor, -grad_V_x*x_dot_tensor)  # minimize the Bellman error
        optimizer.zero_grad()  # reset gradients
        loss.backward()  # compute gradients
        optimizer.step()  # update weights
        print(f"epoch {epoch} loss {loss}")
        epoch = epoch + 1  # count epochs
    l = l + 1
    print(l)
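For what it's worth, a quick check (a sketch, run after the script above) makes me think the Jacobian is detached from the graph unless create_graph=True is passed, and calling requires_grad_() on it afterwards only turns it into a new leaf:

# quick check: is the Jacobian connected to the autograd graph?
J = torch.autograd.functional.jacobian(model, x_0_tensor)
print(J.requires_grad, J.grad_fn)  # False None -> detached from the weights
J = torch.autograd.functional.jacobian(model, x_0_tensor, create_graph=True)
print(J.requires_grad, J.grad_fn is not None)  # True True -> connected
# note: for a batched input of shape (N, 1) the Jacobian has shape (N, 1, N, 1)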