Negative or NaN losses when solving a system of PDEs using the Deep Galerkin Method

Hi there. I have a system of partial differential equations that I want to solve numerically, consisting of a Fokker-Planck equation and a Hamilton-Jacobi-Bellman equation. I found an approach in the applied maths literature under the name Deep Galerkin Method, and someone has actually provided a TensorFlow example for one such system: https://colab.research.google.com/drive/1xqamOTOCw7LRVxCMo1TECGM7st6XeB0H?usp=sharing#scrollTo=DZPFl5hfg4M7
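For context, the core mechanic of the method, as I understand it, is to differentiate the network output with respect to its inputs via autograd and penalise the PDE residual. Here is a minimal toy sketch of that pattern in PyTorch (my own illustration with a made-up residual, not the notebook's code):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1))
x = torch.rand(8, 1, requires_grad=True)                       # collocation points in [0, 1]
u = net(x)                                                     # u_theta(x)
du = torch.autograd.grad(u.sum(), x, create_graph=True)[0]     # du/dx
d2u = torch.autograd.grad(du.sum(), x, create_graph=True)[0]   # d2u/dx2
residual = 0.5 * d2u - du                                      # toy PDE residual, purely illustrative
loss = (residual ** 2).mean()                                  # mean squared residual
loss.backward()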
I went about replicating it in PyTorch. After some weeks of working on this (I'm a beginner) I reached the point where my somewhat convoluted version runs without errors. However, the results are rubbish: it's impossibly slow, and the losses are negative and increasing, or alternatively they show up as NaN. Whether they come out negative or NaN depends on how I fiddle with the learning rate and gradient clipping. I've stayed as true as possible to the original TensorFlow version, deviating only where experts on this forum suggested changes (gradient clipping and a lower learning rate); as far as I can see, those are the only differences from the original code. But I'm at an impasse and don't know how to proceed. Here's my attempt at the code:

import time
import numpy as np
import matplotlib.pyplot as plt
import copy
import torch
import torch.nn as nn
import torch.optim as optim
print(torch.__version__)
# Enable anomaly detection for debugging (this increases runtime, so it should only
# be switched on while hunting down NaNs).
torch.autograd.set_detect_anomaly(True)

# Set random seeds for reproducibility
seed = 42  # any integer works as the seed value

# NumPy and PyTorch
np.random.seed(seed)
torch.manual_seed(seed)


# Define the ErgodicSumOfTrigModel class
class ErgodicSumOfTrigModel:
    def __init__(self, x_min=0, x_max=1):
        self.x_min, self.x_max = x_min, x_max

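    # Hamiltonian of this ergodic test problem: -alpha^2/2 + log(mean_field) plus a
    # trigonometric source term in X. Note that torch.log is only finite for strictly
    # positive mean_field values.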
    def Hamiltonian(self, X, mean_field, alpha):
        value = 0
        value -= 1 / 2 * alpha**2
        value += torch.log(mean_field)
        value += 2 * np.pi ** 2 * (-torch.sin(2 * np.pi * X) + torch.cos(2 * np.pi * X)**2) - 2 * torch.sin(2 * np.pi * X)
        return value

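    # Drift of the Fokker-Planck (KFP) equation: the control alpha enters with a minus
    # sign. (This helper is not referenced in loss_fn below.)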
    def FP_drift(self, X, mean_field, alpha):
        value = -alpha
        return value

# Define the neural network models
class NeuralNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, activation):
        super(NeuralNetwork, self).__init__()
        self.layers = nn.ModuleList([
            nn.Linear(input_dim, hidden_dim),
            activation(),
            nn.Linear(hidden_dim, hidden_dim),
            activation(),
            nn.Linear(hidden_dim, hidden_dim),
            activation(),
            nn.Linear(hidden_dim, output_dim)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# Define the DGMSolver class
class DGMSolver:
    def __init__(self, problem, p_neuralnet, nu_neuralnet, lambda_neuralnet,
                 coef_penal_period=1,
                 coef_penal_normalization=1):
        self.pb = problem
        # NN
        self.p_nn = p_neuralnet
        self.nu_nn = nu_neuralnet
        self.lambda_nn = lambda_neuralnet
        self.coef_penal_period = coef_penal_period
        self.coef_penal_normalization = coef_penal_normalization

        self.loss_train, self.loss_val = [], []
        self.loss_hjb_train, self.loss_hjb_val = [], []
        self.loss_kfp_train, self.loss_kfp_val = [], []
        self.loss_bc_train, self.loss_bc_val = [], []
        self.loss_normal_train, self.loss_normal_val = [], []

        self.p_nn_save = []
        self.nu_nn_save = []

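    # Uniform collocation points in [x_min, x_max]; train() draws a fresh batch every step.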
    def sampling(self, n_sample):
        X0_sample = torch.rand(n_sample, 1) * (self.pb.x_max - self.pb.x_min) + self.pb.x_min
        return X0_sample

    def derivatives(self, net, x):
        # First and second derivatives of the network output with respect to the input.
        x = x.float().requires_grad_(True)  # ensure float dtype and track gradients w.r.t. x
        result_call = net(x)
        dx = torch.autograd.grad(result_call.sum(), x, create_graph=True)[0]
        d2dx2 = torch.autograd.grad(dx.sum(), x, create_graph=True)[0]
        return result_call, dx, d2dx2

    def loss_fn(self, sample):
        X = sample
        n_samples = X.shape[0]
        loss_value = 0

        allones_m = torch.ones(X.shape[0], 1).float()
        lambdaCoef = self.lambda_nn(allones_m)
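        # lambda_nn is a bias-free 1x1 linear layer applied to a column of ones, so
        # lambdaCoef is a single learnable scalar broadcast over the batch.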

        p_theta_X, d1_p_theta_X, d2_p_theta_X = self.derivatives(self.p_nn, X)
        nu_theta_X, d1_nu_theta_X, d2_nu_theta_X = self.derivatives(self.nu_nn, X)

        #### KFP residual
        pde_residual_KFP = 0.5 * d2_nu_theta_X - d2_p_theta_X * nu_theta_X - d1_p_theta_X * d1_nu_theta_X
        pde_residual_KFP = pde_residual_KFP.mean()
        loss_value += pde_residual_KFP

        #### HJB residual
        pde_residual_HJB = lambdaCoef + 0.5 * d2_p_theta_X - self.pb.Hamiltonian(X, nu_theta_X, d1_p_theta_X) - 1
        pde_residual_HJB = pde_residual_HJB.mean()
        loss_value += pde_residual_HJB
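        # Note: both PDE terms above enter the loss as signed batch means of the raw residuals.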

        #### Boundary Condition Residual
        tfpt0 = torch.tensor([[self.pb.x_min]], dtype=torch.float32)
        tfpt1 = torch.tensor([[self.pb.x_max]], dtype=torch.float32)

        p_theta_at_0, d1_p_theta_at_0, d2_p_theta_at_0 = self.derivatives(self.p_nn, tfpt0)
        p_theta_at_1, d1_p_theta_at_1, d2_p_theta_at_1 = self.derivatives(self.p_nn, tfpt1)
        nu_theta_at_0, d1_nu_theta_at_0, d2_nu_theta_at_0 = self.derivatives(self.nu_nn, tfpt0)
        nu_theta_at_1, d1_nu_theta_at_1, d2_nu_theta_at_1 = self.derivatives(self.nu_nn, tfpt1)

        bc_residual_period_p = ((p_theta_at_0 - p_theta_at_1)**2).mean() + ((d1_p_theta_at_0 - d1_p_theta_at_1)**2).mean()
        bc_residual_period_nu = ((nu_theta_at_0 - nu_theta_at_1)**2).mean() + ((d1_nu_theta_at_0 - d1_nu_theta_at_1)**2).mean()
        bc_residual_period = bc_residual_period_p + bc_residual_period_nu
        loss_value += self.coef_penal_period * bc_residual_period

        #### Normalization Residual
        bc_residual_normalization_p = (p_theta_X.mean() - 0.0)**2
        bc_residual_normalization_nu = (nu_theta_X.mean() - 1.0)**2
        bc_residual_normalization = bc_residual_normalization_p + bc_residual_normalization_nu
        loss_value += self.coef_penal_normalization * bc_residual_normalization

        return loss_value, pde_residual_KFP, pde_residual_HJB, bc_residual_period, bc_residual_normalization

    def one_step_grad(self, sample, optimizer, max_norm):
        optimizer.zero_grad()
        loss, kfp, hjb, period, normal = self.loss_fn(sample)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.p_nn.parameters(), max_norm)
        torch.nn.utils.clip_grad_norm_(self.nu_nn.parameters(), max_norm)
        torch.nn.utils.clip_grad_norm_(self.lambda_nn.parameters(), max_norm)
        optimizer.step()
        return loss, kfp, hjb, period, normal

    def train(self, optimizer, n_iterations, n_sample_train, n_sample_validation, max_norm):
        validation_sample = self.sampling(n_sample_validation)  # fix a validation sample once
        start_time = time.time()
        # SGD iterations
        for step in range(0, n_iterations + 1):
            training_sample = self.sampling(n_sample_train)
            loss_train, kfp, hjb, period, normal = self.one_step_grad(training_sample, optimizer, max_norm)

            # Store plain floats so the computation graphs are not kept alive in these lists.
            self.loss_train.append(loss_train.item())
            self.loss_hjb_train.append(hjb.item())
            self.loss_kfp_train.append(kfp.item())
            self.loss_bc_train.append(period.item())
            self.loss_normal_train.append(normal.item())

            if step % 1000 == 0:
                loss_val, kfp_val, hjb_val, period_val, normal_val = self.loss_fn(validation_sample)
                self.loss_val.append(loss_val.item())
                self.loss_hjb_val.append(hjb_val.item())
                self.loss_kfp_val.append(kfp_val.item())
                self.loss_bc_val.append(period_val.item())
                self.loss_normal_val.append(normal_val.item())

                print("iteration = {}, \t\t loss = {}, \t\t total time = {}".format(step, loss_val.item(), time.time() - start_time))

            if step % 10000 == 0:
                self.p_nn_save.append(copy.deepcopy(self.p_nn))
                self.nu_nn_save.append(copy.deepcopy(self.nu_nn))

        return self.p_nn, self.nu_nn, self.lambda_nn



# Create an instance of the ErgodicSumOfTrigModel
pb = ErgodicSumOfTrigModel()

# Define the neural network models
input_dim = 1
hidden_dim = 100
output_dim = 1
activation = nn.Sigmoid

p_nn = NeuralNetwork(input_dim, hidden_dim, output_dim, activation)
nu_nn = NeuralNetwork(input_dim, hidden_dim, output_dim, activation)
lambda_nn = nn.Linear(1, 1, bias=False)

# Create an instance of DGMSolver
mysolver = DGMSolver(problem=pb, p_neuralnet=p_nn, nu_neuralnet=nu_nn, lambda_neuralnet=lambda_nn)
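# Learning rate set very low and gradient clipping used below, following suggestions on this forum.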
adam_optimizer = optim.Adam(list(p_nn.parameters()) + list(nu_nn.parameters()) + list(lambda_nn.parameters()), lr=1e-9)
# Set the max_norm value here
max_norm = 5

# Train the models
p_nn, nu_nn, lambda_nn = mysolver.train(optimizer=adam_optimizer,
                                        n_iterations=200000,
                                        n_sample_train=2048,
                                        n_sample_validation=2048,
                                        max_norm=max_norm)

# Plot the loss curves
loss_train_plot = mysolver.loss_train
loss_val_plot = mysolver.loss_val

fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
ax.plot(np.arange(len(loss_train_plot)), loss_train_plot, label="Training", c="red", alpha=0.5)
ax.plot(1000 * np.arange(len(loss_val_plot)), loss_val_plot, label="Validation", c="blue")
ax.set_xlabel("Iterations")
ax.set_ylabel("Value")
ax.set_title("Total Loss vs Iterations")
ax.legend()
ax.set_yscale('log')

# Plot the other loss curves (HJB, KFP, boundary condition, normalization)
loss_train_plot = mysolver.loss_hjb_train#[10:]
loss_val_plot = mysolver.loss_hjb_val#[10:]

fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
ax.plot(np.arange(len(loss_train_plot)), loss_train_plot, label="Training", c="red", alpha=0.5)
ax.plot(1000 * np.arange(len(loss_val_plot)), loss_val_plot, label="Validation", c="blue")
ax.set_xlabel("Iterations")
ax.set_ylabel("Value")
ax.set_title("HJB Loss vs Iterations")

ax.legend()
ax.set_yscale('log')
loss_train_plot = mysolver.loss_kfp_train#[10:]
loss_val_plot = mysolver.loss_kfp_val#[10:]

fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
ax.plot(np.arange(len(loss_train_plot)), loss_train_plot, label="Training", c="red", alpha=0.5)
ax.plot(1000 * np.arange(len(loss_val_plot)), loss_val_plot, label="Validation", c="blue")
ax.set_xlabel("Iterations")
ax.set_ylabel("Value")
ax.set_title("KFP Loss vs Iterations")

ax.legend()
ax.set_yscale('log')

loss_train_plot = mysolver.loss_bc_train#[10:]
loss_val_plot = mysolver.loss_bc_val#[10:]

fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
ax.plot(np.arange(len(loss_train_plot)), loss_train_plot, label="Training", c="red", alpha=0.5)
ax.plot(1000 * np.arange(len(loss_val_plot)), loss_val_plot, label="Validation", c="blue")
ax.set_xlabel("Iterations")
ax.set_ylabel("Value")
ax.set_title("Boundary Condition Loss vs Iterations")
ax.legend()
ax.set_yscale('log')

loss_train_plot = mysolver.loss_normal_train#[10:]
loss_val_plot = mysolver.loss_normal_val#[10:]

fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
ax.plot(np.arange(len(loss_train_plot)), loss_train_plot, label="Training", c="red", alpha=0.5)
ax.plot(1000 * np.arange(len(loss_val_plot)), loss_val_plot, label="Validation", c="blue")
ax.set_xlabel("Iterations")
ax.set_ylabel("Value")
ax.set_title("Normalization Condition Loss vs Iterations")

ax.legend()
ax.set_yscale('log')

# Plot the final neural network solutions
Xspace_points = 100
Xspace = np.linspace(pb.x_min, pb.x_max, Xspace_points)
benchmark_p_value = np.sin(2 * np.pi * Xspace)
benchmark_nu_value = np.exp(2 * benchmark_p_value) / (np.sum(np.exp(2 * benchmark_p_value)) * 1/Xspace_points)
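# Benchmark: p(x) = sin(2*pi*x) and nu proportional to exp(2*p), normalised by a Riemann sum over the grid.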

Xspace_tensor = torch.tensor(Xspace, dtype=torch.float32).view(-1, 1)
fig, ax = plt.subplots(1, 2, figsize=(20, 8))
indices = [0, 1, 2, len(mysolver.p_nn_save)//2]
for i in indices:
    p_nn_iter = mysolver.p_nn_save[i]
    nu_nn_iter = mysolver.nu_nn_save[i]
    p_value = p_nn_iter(Xspace_tensor).detach().numpy()
    nu_value = nu_nn_iter(Xspace_tensor).detach().numpy()
    ax[0].plot(Xspace, p_value, label=f"Checkpoint {i}", lw=2)
    ax[1].plot(Xspace, nu_value, label=f"Checkpoint {i}", lw=2)
terminal_p_value = p_nn(Xspace_tensor).detach().numpy()
terminal_nu_value = nu_nn(Xspace_tensor).detach().numpy()
ax[0].plot(Xspace, terminal_p_value, label="Last Iterate", lw=2)
ax[1].plot(Xspace, terminal_nu_value, label="Last Iterate", lw=2)

ax[0].plot(Xspace, benchmark_p_value, label="Benchmark Solution", lw=2, c='r', ls='--')
ax[1].plot(Xspace, benchmark_nu_value, label="Benchmark Solution", lw=2, c='r', ls='--')

ax[0].set_xlabel("X")
ax[0].set_ylabel("Value")
ax[0].set_title("Value function")
ax[1].set_xlabel("X")
ax[1].set_ylabel("Value")
ax[1].set_title("Distribution")

ax[0].legend()
ax[1].legend()

plt.show()

I’d really appreciate some insights!