Hi there. I have a system of partial differential equations that I want to solve numerically, consisting of a Fokker-Planck and a Hamiltonian Jacobi Bellman equation. I have found an approach in the applied maths literature under the name of Deep Galerkin Method, with someone actually providing a tensorflow example of one such system. https://colab.research.google.com/drive/1xqamOTOCw7LRVxCMo1TECGM7st6XeB0H?usp=sharing#scrollTo=DZPFl5hfg4M7
I went about replicating it in pytorch. After some weeks working away at this (I’m a beginner) I got to the point of not encountering any errors in my somewhat convoluted version. However, the results are rubbish. It’s impossibly slow, and the losses are negative and increasing. Alternatively, they show up as nan. Whether they’re nan or negative depends on fiddling with the learning rate and gradient clipping. I’ve stayed as true as possible to the original tensorflow version, only deviating upon reading suggestions by the experts on this forum (gradient clipping and a lower learning rate). These, as far as I can see, are the only differences from the original code. But I’m at an impasse, and I don’t know how to proceed. Here’s my attempt at the code:
import time
import numpy as np
import matplotlib.pyplot as plt
import copy
import torch
import torch.nn as nn
import torch.optim as optim
print(torch.__version__)
# Enable anomaly detection
#torch.autograd.detect_anomaly() #Anomaly Detection has been enabled. This mode will increase the runtime and
#should only be enabled for debugging.
torch.autograd.detect_anomaly(True)
# Set random seeds for reproducibility
seed = 42 # You can use any integer as the seed value
# NumPy
np.random.seed(seed)
# Define the ErgodicSumOfTrigModel class
class ErgodicSumOfTrigModel:
def __init__(self, x_min=0, x_max=1):
self.x_min, self.x_max = x_min, x_max
def Hamiltonian(self, X, mean_field, alpha):
value = 0
value -= 1 / 2 * alpha**2
value += torch.log(mean_field)
value += 2 * np.pi ** 2 * (-torch.sin(2 * np.pi * X) + torch.cos(2 * np.pi * X)**2) - 2 * torch.sin(2 * np.pi * X)
return value
def FP_drift(self, X, mean_field, alpha):
value = -alpha
return value
# Define the neural network models
class NeuralNetwork(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, activation):
super(NeuralNetwork, self).__init__()
self.layers = nn.ModuleList([
nn.Linear(input_dim, hidden_dim),
activation(),
nn.Linear(hidden_dim, hidden_dim),
activation(),
nn.Linear(hidden_dim, hidden_dim),
activation(),
nn.Linear(hidden_dim, output_dim)
])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
# Define the DGMSolver class
class DGMSolver:
def __init__(self, problem, p_neuralnet, nu_neuralnet, lambda_neuralnet,
coef_penal_period=1,
coef_penal_normalization=1):
self.pb = problem
# NN
self.p_nn = p_neuralnet
self.nu_nn = nu_neuralnet
self.lambda_nn = lambda_neuralnet
self.coef_penal_period = coef_penal_period
self.coef_penal_normalization = coef_penal_normalization
self.loss_train, self.loss_val = [], []
self.loss_hjb_train, self.loss_hjb_val = [], []
self.loss_kfp_train, self.loss_kfp_val = [], []
self.loss_bc_train, self.loss_bc_val = [], []
self.loss_normal_train, self.loss_normal_val = [], []
self.p_nn_save = []
self.nu_nn_save = []
def sampling(self, n_sample):
X0_sample = torch.rand(n_sample, 1) * (self.pb.x_max - self.pb.x_min) + self.pb.x_min
return X0_sample
def derivatives(self, nn, input):
input = input.float().requires_grad_(True) # Ensure float data type
result_call = nn(input)
dx = torch.autograd.grad(result_call.sum(), input, create_graph=True)[0]
d2dx2 = torch.autograd.grad(dx.sum(), input, create_graph=True)[0]
return result_call, dx, d2dx2
def loss_fn(self, sample):
X = sample
n_samples = X.shape[0]
loss_value = 0
allones_m = torch.ones(X.shape[0], 1).float()
lambdaCoef = self.lambda_nn(allones_m)
p_theta_X, d1_p_theta_X, d2_p_theta_X = self.derivatives(self.p_nn, X)
nu_theta_X, d1_nu_theta_X, d2_nu_theta_X = self.derivatives(self.nu_nn, X)
#### KFP residual
pde_residual_KFP = 0.5 * d2_nu_theta_X - d2_p_theta_X * nu_theta_X - d1_p_theta_X * d1_nu_theta_X
pde_residual_KFP = pde_residual_KFP.mean()
loss_value += pde_residual_KFP
#### HJB residual
pde_residual_HJB = lambdaCoef + 0.5 * d2_p_theta_X - self.pb.Hamiltonian(X, nu_theta_X, d1_p_theta_X) - 1
pde_residual_HJB = pde_residual_HJB.mean()
loss_value += pde_residual_HJB
#### Boundary Condition Residual
tfpt0 = torch.tensor([[self.pb.x_min]], dtype=torch.float32)
tfpt1 = torch.tensor([[self.pb.x_max]], dtype=torch.float32)
p_theta_at_0, d1_p_theta_at_0, d2_p_theta_at_0 = self.derivatives(self.p_nn, tfpt0)
p_theta_at_1, d1_p_theta_at_1, d2_p_theta_at_1 = self.derivatives(self.p_nn, tfpt1)
nu_theta_at_0, d1_nu_theta_at_0, d2_nu_theta_at_0 = self.derivatives(self.nu_nn, tfpt0)
nu_theta_at_1, d1_nu_theta_at_1, d2_nu_theta_at_1 = self.derivatives(self.nu_nn, tfpt1)
bc_residual_period_p = ((p_theta_at_0 - p_theta_at_1)**2).mean() + ((d1_p_theta_at_0 - d1_p_theta_at_1)**2).mean()
bc_residual_period_nu = ((nu_theta_at_0 - nu_theta_at_1)**2).mean() + ((d1_nu_theta_at_0 - d1_nu_theta_at_1)**2).mean()
bc_residual_period = bc_residual_period_p + bc_residual_period_nu
loss_value += self.coef_penal_period * bc_residual_period
#### Normalization Residual
bc_residual_normalization_p = (p_theta_X.mean() - 0.0)**2
bc_residual_normalization_nu = (nu_theta_X.mean() - 1.0)**2
bc_residual_normalization = bc_residual_normalization_p + bc_residual_normalization_nu
loss_value += self.coef_penal_normalization * bc_residual_normalization
return loss_value, pde_residual_KFP, pde_residual_HJB, bc_residual_period, bc_residual_normalization
def one_step_grad(self, sample, optimizer, max_norm):
optimizer.zero_grad()
loss = self.loss_fn(sample)[0]
loss.backward()
torch.nn.utils.clip_grad_norm_(self.p_nn.parameters(), max_norm)
torch.nn.utils.clip_grad_norm_(self.nu_nn.parameters(), max_norm)
torch.nn.utils.clip_grad_norm_(self.lambda_nn.parameters(), max_norm)
optimizer.step()
return loss
def train(self, optimizer, n_iterations, n_sample_train, n_sample_validation, max_norm):
validation_sample = self.sampling(n_sample_validation) # fix a validation sample
start_time = time.time()
# SGD ITERATIONS
for step in range(0, n_iterations + 1):
training_sample = self.sampling(n_sample_train)
loss_train = self.one_step_grad(training_sample, optimizer, max_norm)
if isinstance(loss_train, tuple):
loss_value, kfp, hjb, period, normal = loss_train
else:
loss_value = loss_train
self.loss_train.append(loss_value)
if isinstance(loss_train, tuple):
self.loss_hjb_train.append(hjb)
self.loss_kfp_train.append(kfp)
self.loss_bc_train.append(period)
self.loss_normal_train.append(normal)
if step % 1000 == 0:
loss_val = self.loss_fn(validation_sample)[0]
if isinstance(loss_val, tuple):
loss_value, kfp, hjb, period, normal = loss_val
else:
loss_value = loss_val
self.loss_val.append(loss_value)
if isinstance(loss_val, tuple):
self.loss_hjb_val.append(hjb)
self.loss_kfp_val.append(kfp)
self.loss_bc_val.append(period)
self.loss_normal_val.append(normal)
if step % 1000 == 0:
print("iteration = {}, \t\t loss = {}, \t\t total time = {}".format(step, loss_value, time.time() - start_time))
if step % 10000 == 0:
self.p_nn_save.append(copy.deepcopy(self.p_nn))
self.nu_nn_save.append(copy.deepcopy(self.nu_nn))
return self.p_nn, self.nu_nn, self.lambda_nn
# Create an instance of the ErgodicSumOfTrigModel
pb = ErgodicSumOfTrigModel()
# Define the neural network models
input_dim = 1
hidden_dim = 100
output_dim = 1
activation = nn.Sigmoid
p_nn = NeuralNetwork(input_dim, hidden_dim, output_dim, activation)
nu_nn = NeuralNetwork(input_dim, hidden_dim, output_dim, activation)
lambda_nn = nn.Linear(1, 1, bias=False)
# Create an instance of DGMSolver
mysolver = DGMSolver(problem=pb, p_neuralnet=p_nn, nu_neuralnet=nu_nn, lambda_neuralnet=lambda_nn)
adam_optimizer = optim.Adam(list(p_nn.parameters()) + list(nu_nn.parameters()) + list(lambda_nn.parameters()), lr=1e-9)
# Set the max_norm value here
max_norm = 5
# Train the models
p_nn, nu_nn, lambda_nn = mysolver.train(optimizer=adam_optimizer,
n_iterations=200000,
n_sample_train=2048,
n_sample_validation=2048,
max_norm=1.0)
# Plot the loss curves
loss_train_plot = mysolver.loss_train
loss_val_plot = mysolver.loss_val
fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
ax.plot(np.arange(len(loss_train_plot)), loss_train_plot, label="Training", c="red", alpha=0.5)
ax.plot(1000 * np.arange(len(loss_val_plot)), loss_val_plot, label="Validation", c="blue")
ax.set_xlabel("Iterations")
ax.set_ylabel("Value")
ax.set_title("Total Loss vs Iterations")
ax.legend()
ax.set_yscale('log')
# Plot the loss curves
loss_train_plot = mysolver.loss_train
loss_val_plot = mysolver.loss_val
fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
ax.plot(np.arange(len(loss_train_plot)), loss_train_plot, label="Training", c="red", alpha=0.5)
ax.plot(1000 * np.arange(len(loss_val_plot)), loss_val_plot, label="Validation", c="blue")
ax.set_xlabel("Iterations")
ax.set_ylabel("Value")
ax.set_title("Total Loss vs Iterations")
ax.legend()
ax.set_yscale('log')
# Plot other loss curves (HJB, KFP, Boundary Condition, Normalization)
# ... (similar to the previous code)
loss_train_plot = mysolver.loss_hjb_train#[10:]
loss_val_plot = mysolver.loss_hjb_val#[10:]
fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
ax.plot(np.arange(len(loss_train_plot)), loss_train_plot, label="Training", c="red", alpha=0.5)
ax.plot(1000 * np.arange(len(loss_val_plot)), loss_val_plot, label="Validation", c="blue")
ax.set_xlabel("Iterations")
ax.set_ylabel("Value")
ax.set_title("HJB Loss vs Iterations")
ax.legend()
ax.set_yscale('log')
loss_train_plot = mysolver.loss_kfp_train#[10:]
loss_val_plot = mysolver.loss_kfp_val#[10:]
fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
ax.plot(np.arange(len(loss_train_plot)), loss_train_plot, label="Training", c="red", alpha=0.5)
ax.plot(1000 * np.arange(len(loss_val_plot)), loss_val_plot, label="Validation", c="blue")
ax.set_xlabel("Iterations")
ax.set_ylabel("Value")
ax.set_title("KFP Loss vs Iterations")
ax.legend()
ax.set_yscale('log')
loss_train_plot = mysolver.loss_bc_train#[10:]
loss_val_plot = mysolver.loss_bc_val#[10:]
fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
ax.plot(np.arange(len(loss_train_plot)), loss_train_plot, label="Training", c="red", alpha=0.5)
ax.plot(1000 * np.arange(len(loss_val_plot)), loss_val_plot, label="Validation", c="blue")
ax.set_xlabel("Iterations")
ax.set_ylabel("Value")
ax.set_title("Boundary Condition Loss vs Iterations")
ax.legend()
ax.set_yscale('log')
loss_train_plot = mysolver.loss_normal_train#[10:]
loss_val_plot = mysolver.loss_normal_val#[10:]
fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
ax.plot(np.arange(len(loss_train_plot)), loss_train_plot, label="Training", c="red", alpha=0.5)
ax.plot(1000 * np.arange(len(loss_val_plot)), loss_val_plot, label="Validation", c="blue")
ax.set_xlabel("Iterations")
ax.set_ylabel("Value")
ax.set_title("Normalization Condition Loss vs Iterations")
ax.legend()
ax.set_yscale('log')
# Plot the final neural network solutions
Xspace_points = 100
Xspace = np.linspace(pb.x_min, pb.x_max, Xspace_points)
benchmark_p_value = np.sin(2 * np.pi * Xspace)
benchmark_nu_value = np.exp(2 * benchmark_p_value) / (np.sum(np.exp(2 * benchmark_p_value)) * 1/Xspace_points)
Xspace_tensor = torch.tensor(Xspace, dtype=torch.float32).view(-1, 1)
fig, ax = plt.subplots(1, 2, figsize=(20, 8))
indices = [0, 1, 2, len(mysolver.p_nn_save)//2]
for i in indices:
p_nn_iter = mysolver.p_nn_save[i]
nu_nn_iter = mysolver.nu_nn_save[i]
p_value = p_nn_iter(Xspace_tensor).detach().numpy()
nu_value = nu_nn_iter(Xspace_tensor).detach().numpy()
ax[0].plot(Xspace, p_value, label=f"Checkpoint {i}", lw=2)
ax[1].plot(Xspace, nu_value, label=f"Checkpoint {i}", lw=2)
terminal_p_value = p_nn(Xspace_tensor).detach().numpy()
terminal_nu_value = nu_nn(Xspace_tensor).detach().numpy()
ax[0].plot(Xspace, terminal_p_value, label="Last Iterate", lw=2)
ax[1].plot(Xspace, terminal_nu_value, label="Last Iterate", lw=2)
ax[0].plot(Xspace, benchmark_p_value, label="Benchmark Solution", lw=2, c='r', ls='--')
ax[1].plot(Xspace, benchmark_nu_value, label="Benchmark Solution", lw=2, c='r', ls='--')
ax[0].set_xlabel("X")
ax[0].set_ylabel("Value")
ax[0].set_title("Value function")
ax[1].set_xlabel("X")
ax[1].set_ylabel("Value")
ax[1].set_title("Distribution")
ax[0].legend()
ax[1].legend()
plt.show()
I’d really appreciate some insights!