Neural ODE predator-prey toy example

Hello to all,

I am new to machine learning, and I have been trying to fit a neural ODE to the Lotka-Volterra equations.
Based on the original neural ODE publication, I have prepared the following code:

# Modules
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn

from torchdiffeq import odeint, odeint_adjoint

import os

def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)

def visualize(t, y, y_pred=None, iteration=None):
    fig = plt.figure(figsize=(8, 5), facecolor='white')
    ax1 = fig.add_subplot(111)
    ax1.plot(t.cpu().numpy(), y.cpu().numpy()[:, 0, 0], 'k', label=r"True $u_1$")
    ax1.plot(t.cpu().numpy(), y.cpu().numpy()[:, 0, 1], 'g', label=r"True $u_2$")
    if y_pred is not None:
        ax1.plot(t.cpu().numpy(), y_pred.cpu().numpy()[:, 0, 0], 'r--', label=r"Learned $u_1$")
        ax1.plot(t.cpu().numpy(), y_pred.cpu().numpy()[:, 0, 1], 'b--', label=r"Learned $u_2$")
    ax1.set_xlim(0, 15)
    ax1.set_ylim(0, 4)
    ax1.set_xlabel(r"$t$")
    ax1.set_ylabel(r"$u_1$, $u_2$")
    ax1.legend()

    if iteration is not None:
        makedirs('png')
        plt.savefig(f'png/{iteration:04d}.png')
    plt.show()
# Classes for neural ODE

n_hidden = 64
class NeuralODE(nn.Module):
    def __init__(self):
        super(NeuralODE, self).__init__()
        # Tanh activations between the linear layers; without a
        # nonlinearity the stack collapses to a single linear map
        # and cannot represent the quadratic interaction terms.
        self.net = nn.Sequential(
            nn.Linear(2, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, 2),
        )
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, val=0)

    def forward(self, t, y):
        return self.net(y)

# Ground-truth dynamic of the system
class Lambda(nn.Module):
    def forward(self, t, y):
        a, b, c, d = 1, 2, 3, 4
        x1 = y[0][0]
        x2 = y[0][1]
        return torch.tensor([[a*x1 - b*x1*x2,
                              -c*x2 + d*x1*x2]]).to(device)

def get_batch():
    s = torch.from_numpy(np.random.choice(np.arange(data_size - batch_time, dtype=np.int64), batch_size, replace=False))
    batch_y0 = true_y[s]  # (batch_size, 1, vector)
    batch_t = t[:batch_time]  # (T)
    batch_y = torch.stack([true_y[s + i] for i in range(batch_time)], dim=0)  # (time, batch_size, 1, vector)
    return batch_y0.to(device), batch_t.to(device), batch_y.to(device)

# Create the data set
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data_size = 1000
true_y0 = torch.tensor([[1., 3.]]).to(device)
t = torch.linspace(0., 15., data_size).to(device)

# Batch info
batch_time = 10
batch_size = 20

with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method='dopri5')    
# Visualize the data
visualize(t, true_y, iteration=0)

# train the network
func = NeuralODE().to(device)

# loss_fn = torch.nn.MSELoss()
loss_fn = torch.nn.L1Loss()

optimizer = torch.optim.Adam(func.parameters(), lr=1e-3)

epochs = 2000
losses = []
losses_val = []
odeMethod = 'euler'

print(f'Working on: {device}')

for epoch in range(1, epochs + 1):
    optimizer.zero_grad()
    initial_states, times, targets = get_batch()
    pred_Y = odeint_adjoint(func=func, y0=initial_states, t=times, method=odeMethod)
    loss = loss_fn(pred_Y, targets)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    if epoch % 50 == 0:
        with torch.no_grad():
            pred_y = odeint(func, true_y0, t, method=odeMethod)
            val_loss = loss_fn(pred_y, true_y)
            losses_val.append(val_loss.item())
            print(f'Iteration {epoch} | Total loss {val_loss.item():.6f}')
            visualize(t, true_y, pred_y, epoch)

However, the fit I am getting is not very good :confused: . The loss is oscillating between high and low values. Within my novice knowledge I tried playing with some parameters (learning rate, batch size), but so far to no avail.
I was hoping the loss would decrease with training, with the predictions getting closer and closer to the actual values with each epoch. Can anyone help me with this?

Best Regards

Your code seems to be mostly correct, but I can suggest a few adjustments that might help stabilize the training and improve the results.

  1. Change the learning rate: It’s possible that the learning rate is too high, which causes instability in the training process. Try reducing the learning rate, for example, to 1e-4 or even lower.
  2. Use weight decay: Adding weight decay (L2 regularization) to your optimizer might help the model to generalize better. You can add weight decay to the Adam optimizer like this:
optimizer = torch.optim.Adam(func.parameters(), lr=1e-4, weight_decay=1e-5)

  3. Decrease the batch size: Using a smaller batch size might help stabilize the training process. You can try reducing the batch size, for example, to 10 or even smaller.
  4. Increase the number of hidden units: You can try increasing the number of hidden units in your neural network (e.g., to 128 or 256) to give the model more capacity to learn the dynamics of the system.
  5. Train for more epochs: Sometimes, the model might need more epochs to converge to a good solution. You can try increasing the number of training epochs.
  6. Initialize weights more conservatively: The initialization of the weights can have a significant impact on the training process. You can try initializing the weights with a smaller standard deviation, like this:
if isinstance(m, nn.Linear):
    nn.init.normal_(m.weight, mean=0, std=0.01)
    nn.init.constant_(m.bias, val=0)
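Beyond a one-off reduction, the learning-rate change in point 1 can also be applied gradually during training with one of PyTorch's built-in schedulers. A minimal sketch: the small `func` network here is just a stand-in for your `NeuralODE`, and the 500-epoch/0.5 schedule is an arbitrary choice you would tune:

```python
import torch
from torch import nn

# Stand-in for the NeuralODE network from the question.
func = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 2))

# Lower initial learning rate plus weight decay, as suggested above.
optimizer = torch.optim.Adam(func.parameters(), lr=1e-4, weight_decay=1e-5)

# Halve the learning rate every 500 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

for epoch in range(1000):
    # ... compute loss, loss.backward(), optimizer.step() as in your loop ...
    optimizer.step()   # placeholder: no gradients here, so this is a no-op
    scheduler.step()   # decay the learning rate on the fixed schedule

print(optimizer.param_groups[0]['lr'])
```

With this schedule the rate is halved at epochs 500 and 1000, so the oscillations you see late in training get damped even if the initial rate is still a bit too aggressive.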