Two inputs but results are only functions of one

Dear all,

I’ve set up an algorithm to solve a system of two equations. The equations feature two objects (v_values, g_values) that I model as neural networks (v_nn, g_nn), and both networks take two inputs (a_input, z_input). The two equations are solved jointly, or at least that’s the idea, so v_values and g_values are equilibrium objects, if you will.

The code plots the results, and from those plots it is clear that the networks are invariant to a_input. I cannot find why that is. The equations feature derivatives of the neural networks, and of the neural networks multiplied by deterministic vectors, with respect to both inputs; I’ve printed these derivatives out and they seem comparable in size. The neural networks use a standard feed-forward architecture, the optimiser is Adam, and the algorithm would be classified as some type of physics-informed neural network. I cannot see what I’ve messed up, and I’ve been stuck for weeks. I would greatly appreciate some suggestions, since if I can’t resolve this I will have to abandon the project. The full code is below.
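
To make the symptom concrete, this is the kind of check I mean by "invariant to a_input". It is a sketch only, not part of the solver; it assumes the trained v_nn and the amin, amax, zmin, zmax, z_target values defined in the code, and would be run after training has finished:

# Run after the script below has finished training; torch, v_nn and the grid bounds come from there.
# Evaluate the trained network along each input separately and compare the spreads.
with torch.no_grad():
    a_grid = torch.linspace(amin, amax, 128).unsqueeze(1)    # (128, 1) wealth grid
    z_grid = torch.linspace(zmin, zmax, 128).unsqueeze(1)    # (128, 1) productivity grid
    v_along_a = v_nn(a_grid, torch.full_like(a_grid, float(z_target)))  # vary a, hold z at its mean
    v_along_z = v_nn(torch.full_like(z_grid, 1.0), z_grid)              # hold a fixed at 1.0 (inside [amin, amax]), vary z
    print("spread of v along a:", (v_along_a.max() - v_along_a.min()).item())
    print("spread of v along z:", (v_along_z.max() - v_along_z.min()).item())

The surface plots at the end of the script show the same flatness along the a axis.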

import torch
import copy
import time
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import grad  
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import pandas as pd

rho = 0.05           # Discount rate
gamma = 2            # CRRA utility parameter
sigma_squared = 0.01 # sigma^2 (variance) of O-U process
z_target = 1         # mean O-U process (in levels)
alpha = 0.35         # production function F = K^alpha * L^(1-alpha)
delta = 0.1          # Capital depreciation 10%

Corr = np.exp(-0.3)  # persistence -log(Corr) of the O-U process; see page 34 of the 'online appendix, not for publication' of Lions et al. 2022
sigma = np.sqrt(2 * sigma_squared / (1 - Corr**2))  # sigma of O-U process

# I need to sample from the ergodic distribution of the state vector (a,z) to get the initial values for the Xs

I=128               # number of a points
J=128               # number of z points
zmin = 0.5          # Range z
zmax = 1.5
amin = -1           # range a
amax = 30

a = np.linspace(amin, amax, num=I)[:, np.newaxis]   # wealth vector 15/01 added a new axis to make it 2D so that mu is 2d
z = np.linspace(zmin, zmax, num=J)[:, np.newaxis]   # productivity vector

# Convert the grids to float tensors of shape (I, 1) / (J, 1); linspace already returns them sorted
a_input = torch.from_numpy(a).float()
z_input = torch.from_numpy(z).float()

k = 3.7404                          # solution to steady state capital -ordinarily this would have to be solved by an outer loop
r = alpha * k**(alpha - 1) - delta  # initial guess for interest rates 
w = (1 - alpha) * k**(alpha)        # initial guess for wages from M-N
# Ornstein-Uhlenbeck in levels
the = -(np.log(Corr))
Var = sigma_squared / (2 * the)

mu = torch.tensor(the) * (z_target - z)  # Convert 'the' to a PyTorch tensor
s2 = torch.tensor(sigma_squared * np.ones(J))

a_bar= 1
KAPPA=3

batchSize = 128 #should be a power of 2 (can try 256, 512)

# Value function initial guess
VFInitGuess = -8 
# Set global seed

torch.manual_seed(1234)
np.random.seed(1234)

class Exp(nn.Module):
    def forward(self, input):
        return torch.exp(input)

class V_NN(nn.Module):
    def __init__(self, initGuess=0):
        super(V_NN, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(2, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
            
            
        )
        nn.init.constant_(self.layers[-1].bias, initGuess)  # Initialize the last layer bias to the initial guess
        

        # Apply Xavier initialization
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)

    def forward(self, a_input, z_input):
        if len(a_input.shape) == 1:
            a_input = a_input.unsqueeze(1)
        if len(z_input.shape) == 1:
            z_input = z_input.unsqueeze(1)
        x = torch.cat((a_input, z_input), dim=1) #14/01
        return self.layers(x)

class g_NN(nn.Module):
    def __init__(self):
        super(g_NN, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(2, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
            nn.Softplus() #output of nn will always be +ve

        )
        # Apply Xavier initialization
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)


    def forward(self, a_input, z_input):
        if len(a_input.shape) == 1:        #all this on 14/01
            a_input = a_input.unsqueeze(1)
        if len(z_input.shape) == 1:
            z_input = z_input.unsqueeze(1)
        x = torch.cat((a_input, z_input), dim=1) #14/01
        return self.layers(x)

        
class DGMSolver:

    def __init__(self, v_neuralnet, g_neuralnet):    
        
        self.v_nn = v_neuralnet
        self.g_nn = g_neuralnet
        
        self.loss_train, self.loss_val = [], []
        self.loss_hjb_train, self.loss_hjb_val = [], []
        self.loss_kfp_train, self.loss_kfp_val = [], []

        self.v_nn_save = []
        self.g_nn_save = []

    def sampling(self, n_sample, a, amin, amax, z, zmin, zmax):
        a_input = amin + (amax - amin) * torch.rand((n_sample, 1))
        z_input = zmin + (zmax - zmin) * torch.rand((n_sample, 1))
        X0_sample = torch.cat((a_input, z_input), dim=1)
        #print("X0_sample:", X0_sample)
        return X0_sample

    def derivatives(self, nn, X):
        a_input, z_input = X.unbind(1)
        a_input.requires_grad_(True)
        z_input.requires_grad_(True)
        v = nn(a_input, z_input)
        dv_da = torch.autograd.grad(v, a_input, grad_outputs=torch.ones_like(v), create_graph=True)[0]
        dv_dz = torch.autograd.grad(v, z_input, grad_outputs=torch.ones_like(v), create_graph=True)[0]
        d2v_dz2 = torch.autograd.grad(dv_dz, z_input, grad_outputs=torch.ones_like(dv_dz), create_graph=True)[0]
        d2v_da2 = torch.autograd.grad(dv_da, a_input, grad_outputs=torch.ones_like(dv_da), create_graph=True)[0]
        return v, dv_da, dv_dz, d2v_dz2, d2v_da2

   
    def deriv_density_times_savings(self, nn, s, a_input, z_input):
        a_input.requires_grad_(True)
        z_input.requires_grad_(True)
        g_s = nn(a_input, z_input) * s
        dg_s_da = torch.autograd.grad(g_s, a_input, grad_outputs=torch.ones_like(g_s), create_graph=True)[0]

        return g_s, dg_s_da

    def deriv_density_times_drift(self, nn, the, z_target, X):
        a_input, z_input = X.unbind(1)
        z_input.requires_grad_(True)
        density_times_drift = the * (z_target - z_input) * nn(a_input, z_input)
        density_times_drift_dz = torch.autograd.grad(density_times_drift, z_input, grad_outputs=torch.ones_like(density_times_drift), create_graph=True)[0]
        return density_times_drift_dz

    def SecDeriv_density_times_sigmasq(self, nn, sigma_squared, X):
        a_input, z_input = X.unbind(1)
        z_input.requires_grad_(True)
        g_sigma_squared = nn(a_input, z_input) * sigma_squared #there should be only z_input here, right? I think
        dg_sigma_squared_dz = torch.autograd.grad(g_sigma_squared, z_input, grad_outputs=torch.ones_like(g_sigma_squared), create_graph=True)[0]
        d2g_sigma_squared_dz2 = torch.autograd.grad(dg_sigma_squared_dz, z_input, grad_outputs=torch.ones_like(dg_sigma_squared_dz), create_graph=True)[0]
        return d2g_sigma_squared_dz2

    def loss_fn(self, sample):
        self.a_input, self.z_input = sample[:, 0].unsqueeze(1), sample[:, 1].unsqueeze(1)
        loss_value = 0

        # Constants and parameters
        lambda_reg = 0.5
        lambda_reg2 = 0.5
        lambda_reg3 = 0.5
        gamma = 2
        KAPPA = 3
        a_bar = 1
        rho = 0.05
        sigma_squared = 0.01
        z_target = 1
        alpha = 0.35
        delta = 0.1
        Corr = np.exp(-0.3)
        sigma = np.sqrt(2 * sigma_squared / (1 - Corr ** 2))
        the = -(np.log(Corr))
        Var = sigma_squared / (2 * the)
        mu = torch.tensor(the) * (z_target - z_input)
        s2 = torch.tensor(sigma_squared * np.ones(J)).unsqueeze(1) #15/01 reshape s2 to [128, 1] instead of [128]

        # Compute derivatives for v and s*g
        v, dv_da, dv_dz, d2v_dz2, d2v_da2 = self.derivatives(self.v_nn, torch.stack((a_input, z_input), dim=1))
        #print("dv_dz:", dv_dz)
        #print("dv_da:", dv_da)
        epsilon = 1e-8
        c = (dv_da.abs() + epsilon)**(-1/gamma)
        U = c.pow(1 - gamma) / (1 - gamma)
        penalty = torch.where(a_input <= a_bar, -0.5 * KAPPA * (a_input - a_bar) ** 2, torch.tensor(0.0))
        U += penalty

        w = (1 - alpha) * k**alpha
        r = alpha * k ** (alpha - 1) - delta
        self.s = w*z_input + r*a_input - c
        self.c = c    
        
        _, dg_s_da = self.deriv_density_times_savings(self.g_nn, self.s, a_input, z_input)
        #print("dg_s_da:", dg_s_da)
        d2g_sigma_squared_dz2 = self.SecDeriv_density_times_sigmasq(self.g_nn, sigma_squared, torch.stack((a_input, z_input), dim=1))
        density_times_drift_dz = self.deriv_density_times_drift(self.g_nn, the, z_target, torch.stack((a_input, z_input), dim=1))
        #print("density_times_drift_dz:", density_times_drift_dz)

        pde_residual_HJB = U + (w * z_input + r * a_input - c) * dv_da + mu * dv_dz + 0.5 * sigma_squared * d2v_dz2 - rho * v

        pde_residual_HJB = torch.mean(torch.square(pde_residual_HJB))

        loss_value += pde_residual_HJB

        # Compute pde_residual_KFP

        pde_residual_KFP = -dg_s_da - density_times_drift_dz + 0.5 * d2g_sigma_squared_dz2
        pde_residual_KFP = torch.mean(torch.square(pde_residual_KFP))

        loss_value += pde_residual_KFP

        penalty_term = torch.relu(d2v_da2).mean() #12/2/24 #relu function creates a term that is zero when d2v_da2 is less than or equal to zero
        #and increases linearly as d2v_da2 increases, so it penalises d2v_da2 values that are greater than zero

        # Print the values of d2v_da2 and penalty_term
        #print("d2v_da2: ", d2v_da2)
        #print("penalty_term: ", penalty_term)
        
        loss_value += lambda_reg * penalty_term   #12/2/24
        
        # Compute penalty term for negative dv_da (dv_da was already computed above, no need to recompute it)
        penalty_term2 = torch.relu(-dv_da).mean()  # relu function creates a term that is zero when dv_da is non-negative
         # and increases linearly as dv_da decreases, so it penalises dv_da values that are negative
          # Add penalty term to loss
        loss_value += lambda_reg3 * penalty_term2
        #print("dv_da:", dv_da)
        #print("penalty_term2: ", penalty_term2)
        
        # Compute derivatives at zmin and zmax
        _, _, dv_dz_at_zmin, _, _ = self.derivatives(self.v_nn, torch.stack((a_input, torch.tensor([[zmin]]*len(a_input))), dim=1))
        _, _, dv_dz_at_zmax, _, _ = self.derivatives(self.v_nn, torch.stack((a_input, torch.tensor([[zmax]]*len(a_input))), dim=1))

        # Compute the mean squared error of the derivatives at zmin and zmax
        mse_term = torch.mean((dv_dz_at_zmin)**2 + (dv_dz_at_zmax)**2)


        # Add regularization term to loss
        loss_value += lambda_reg2 * mse_term


        return loss_value, pde_residual_HJB, pde_residual_KFP


    def one_step_grad(self, sample, optimizer):
        # Zero the gradients

        optimizer.zero_grad()
        # Compute the loss
        loss_value, pde_residual_HJB, pde_residual_KFP = self.loss_fn(sample)

        # Backpropagate the gradients
        loss_value.backward()
        # Clip the gradients
        torch.nn.utils.clip_grad_norm_(self.v_nn.parameters(), max_norm=0.01)
        torch.nn.utils.clip_grad_norm_(self.g_nn.parameters(), max_norm=0.01)
        

        # Apply the gradients
        optimizer.step()

        return loss_value, pde_residual_HJB, pde_residual_KFP #, bc_residual_period, bc_residual_normalization

    def train(self, optimizer, n_iterations, n_sample_train):
        start_time = time.time()
        # SGD ITERATIONS
        for step in range(0, n_iterations + 1):
            training_sample = self.sampling(n_sample_train, a, amin, amax, z, zmin, zmax) #13/01 18/01 commented out

            loss_train = self.one_step_grad(training_sample, optimizer)
            loss_value, hjb, kfp = loss_train
            self.loss_train.append(loss_value.item())
            self.loss_hjb_train.append(hjb.item())   # store scalars so the autograd graph is not kept alive
            self.loss_kfp_train.append(kfp.item())
 

            if (step % 1000 == 0):
                    print("iteration = {}, \t\t loss = {}, \t\t total time = {}".format(step, loss_value.item(),
                                                                                    time.time() - start_time))

            if (step % 10000 == 0):
                    self.v_nn_save.append(copy.deepcopy(self.v_nn))
                    self.g_nn_save.append(copy.deepcopy(self.g_nn))
        return self.v_nn, self.g_nn#, self.lambda_nn

# Instantiate the neural networks and the solver
v_nn = V_NN(initGuess=VFInitGuess) #network is initialized
g_nn = g_NN()

mysolver = DGMSolver(v_neuralnet=v_nn,g_neuralnet=g_nn)
adam_optimizer = torch.optim.Adam(list(v_nn.parameters()) + list(g_nn.parameters())) # Adam doesn't require me to fix a learning rate; it's adaptive and a strong default optimizer for NNs (see https://ruder.io/optimizing-gradient-descent/index.html#adam)

v_nn, g_nn  = mysolver.train(optimizer=adam_optimizer,
                             #n_iterations=230000,
                             n_iterations=2000,
                             n_sample_train=128)

loss_train_plot = mysolver.loss_train#[1000:]


fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
ax.plot(np.arange(len(loss_train_plot)), loss_train_plot, label="Training", c="red", alpha=0.5)
ax.set_xlabel("Iterations")
ax.set_ylabel("Value")
ax.set_title("Total Loss vs Iterations")
ax.legend()
ax.set_yscale('log')

plt.savefig('1.png')
plt.show()

loss_train_plot = mysolver.loss_hjb_train#[10:]
fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
ax.plot(np.arange(len(loss_train_plot)), loss_train_plot, label="Training", c="red", alpha=0.5)
ax.set_xlabel("Iterations")
ax.set_ylabel("Value")
ax.set_title("HJB Loss vs Iterations")

ax.legend()
ax.set_yscale('log')

plt.savefig('2.png')
plt.show()

loss_train_plot = mysolver.loss_kfp_train#[10:]

fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
ax.plot(np.arange(len(loss_train_plot)), loss_train_plot, label="Training", c="red", alpha=0.5)

ax.set_xlabel("Iterations")
ax.set_ylabel("Value")
ax.set_title("KFP Loss vs Iterations")

ax.legend()
ax.set_yscale('log')

plt.savefig('3.png')
plt.show()

v_values = v_nn(a_input, z_input).detach().numpy()
g_values = g_nn(a_input, z_input).detach().numpy()

# Plot for value function
def plotValueFunction(a_input, z_input, v_values):
    a_input = a_input.detach().numpy()
    z_input = z_input.detach().numpy()
    # No need for the with torch.no_grad() block since v_values is already a NumPy array
    
    fig = plt.figure(figsize=(11, 7), dpi=100)
    ax = fig.add_subplot(111, projection='3d')  # Create subplot with 3D projection
    X, Y = np.meshgrid(a_input, z_input)
    surf = ax.plot_surface(X, Y, v_values, rstride=1, cstride=1, cmap=cm.viridis, linewidth=0, antialiased=False)
    ax.set_xlim(-1, 30)
    ax.set_ylim(0.5, 1.5)
    ax.view_init(30, 225)   
    ax.set_xlabel('$a_{input}$')
    ax.set_ylabel('$z_{input}$')
    
    # Add label to the vertical axis
    ax.text(0, 0.75, 0, "Value Function", color='red', fontsize=12, rotation='vertical')
    
    # Save the figure
    plt.savefig('plotValueFunction.png')
    plt.show()

# Plot for distribution
def plotDensity(a_input, z_input, g_values):
    a_input = a_input.detach().numpy()
    z_input = z_input.detach().numpy()
    # No need for the with torch.no_grad() block since g_values is already a NumPy array
    
    fig = plt.figure(figsize=(11, 7), dpi=100)
    ax = fig.add_subplot(111, projection='3d')  # Create subplot with 3D projection
    X, Y = np.meshgrid(a_input, z_input)
    surf = ax.plot_surface(X, Y, g_values, rstride=1, cstride=1, cmap=cm.viridis, linewidth=0, antialiased=False)
    ax.set_xlim(-1, 30)
    ax.set_ylim(0.5, 1.5)
    ax.view_init(30, 225)   
    ax.set_xlabel('$a_{input}$')
    ax.set_ylabel('$z_{input}$')
    
    # Add label to the vertical axis
    ax.text(0, 0.75, 0, "g(a,z)", color='red', fontsize=12, rotation='vertical')
    
    # Save the figure
    plt.savefig('plotDensity.png')
    plt.show()

# Example usage
plotValueFunction(a_input, z_input, v_values)
plotDensity(a_input, z_input, g_values)