Modelling Hidden Dynamics using RNN based ODE

Hello everyone,

I’m trying to model the hidden dynamics of a system in order to predict the continuous-time behavior (trajectory) of its output by solving an ODE-based RNN (ODE: Ordinary Differential Equation).


The idea and the algorithms are described in detail here:
the paper
torchdiffeq library

Idea in nutshell

Basically, I’m trying to move from a standard RNN model, which can learn discrete behavior, to a general model that learns and predicts continuous-time behavior. The idea is to train a network to learn the rate of change of the hidden states; then, using an accurate solver, the initial value problem (IVP) can be solved to get the states/output at the evaluation instants.

step 0

#import modules
import os
import argparse

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time

import torch
import torch.nn as nn
import torch.optim as optim
#from torchdiffeq import odeint_adjoint as odeint #backprop. using adjoint method integrated in odeint
from torchdiffeq import odeint as odeint

from torch.utils.tensorboard import SummaryWriter
import shutil

#pick the compute device: CUDA when torch reports one, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using {} device".format(device))

step 1

#create train dataset

#load the training csv and name its three columns: time, input, output
df_train = pd.read_csv('some_csv_file')
df_train.columns = ['T', 'in', 'out']

#nbatchs must divide the number of rows, otherwise the reshape below fails
nbatchs = 100

#cut every column into nbatchs equal sequences of shape (nbatchs, seq_len, 1)
_batched = (nbatchs, -1, 1)
_t_us = 1e6*df_train[['T']].values.reshape(_batched)     #in us
_x_volts = df_train[['in']].values.reshape(_batched)     #in V
_y_volts = df_train[['out']].values.reshape(_batched)    #in V

#move the sequences to the selected device
#NOTE(review): t_train keeps the numpy dtype while x_train / y_train are cast
#to float32 - confirm the solver is fine with the mixed dtypes
t_train = torch.tensor(_t_us).to(device)
x_train = torch.tensor(_x_volts).float().to(device)
y_train = torch.tensor(_y_volts).float().to(device)

#check sizes

step 2

#define a fn that handles data and returns:
#data= list([init_x[i], init_y[i], time_series[i], targets[i]] for i in range(nbatchs))
#init_state=  torch.zeros(1, 1, hidden_size)

#NOTE(review): data_size is hard-coded; presumably it has to equal the number of
#rows in the training csv - confirm against the actual file
data_size = 1000
eval_pts = 10                            #no. of eval pts for integration
seq_len = data_size // nbatchs           #batch length
#stride used to subsample each batch; clamped to 1 because seq_len < eval_pts
#would give a stride of 0 and slicing with step 0 raises ValueError
s = max(1, seq_len // eval_pts)          #sampling rate
niters = 1000                            #no. of iterations
test_freq = 2                            #test frequency
hidden_size = 10                         #size of hidden layer

def get_data():
    """Build the per-batch training samples and the initial hidden state.

    Reads the module-level x_train, y_train, t_train, nbatchs, s and
    hidden_size.

    Returns:
        data: list with one [x0, y0, t, y] entry per batch, where
            x0 / y0 are the first input / output sample of the batch viewed
            as (1, 1, 1), t is the batch's time stamps taken every s-th
            sample and flattened to 1-D, and y is the matching output
            samples viewed as (n, 1, 1, 1).
        init_state: zero tensor of shape (1, 1, hidden_size), created on the
            same device as x_train.
        targets: list of the per-batch y tensors (the same objects stored
            in data).
    """
    batches = range(nbatchs)
    x0 = [x_train[b, 0].view(-1, 1, 1) for b in batches]
    y0 = [y_train[b, 0].view(-1, 1, 1) for b in batches]
    t = [t_train[b, ::s].view(-1) for b in batches]
    y = [y_train[b, ::s].view(-1, 1, 1, 1) for b in batches]
    data = [list(sample) for sample in zip(x0, y0, t, y)]
    #the hidden state is integrated together with y0, so it has to live on the
    #same device as the data; a plain torch.zeros() would always be on the CPU
    init_state = torch.zeros(1, 1, hidden_size, device=x_train.device)
    targets = y
    return data, init_state, targets

step 3
Thanks to @albanD:

#This class trains func -> (dy/dt) and solves for y at predefined eval_pts 

#running loss of the current iteration; reset by the caller, increased by train()
tot_loss = 0.0


class ODE_RNN_TBPTT():
    """Train func (the dy/dt, dh/dt network) batch by batch through odeint.

    The hidden state is carried from one batch to the next but detached in
    between, so gradients never flow across a batch boundary.
    """

    def __init__(self, func, loss_fn, k, optimizer):
        """Store the collaborators.

        Args:
            func: module passed to odeint; must offer get_new_observation().
            loss_fn: callable applied as loss_fn(pred, targets).
            k: a parameter update is done on every k-th batch.
            optimizer: optimizer holding the parameters of func.
        """
        self.func = func
        self.loss_fn = loss_fn
        self.k = k
        self.optimizer = optimizer

    def train(self, data, init_state):
        """Run one pass over data, updating func on every k-th batch.

        Args:
            data: iterable of (x0, y0, t, targets) entries, one per batch.
            init_state: hidden state used for the first batch.

        Side effects:
            Adds the (python float) loss of every evaluated batch to the
            module-level tot_loss and steps the optimizer. Only the
            prediction of the k-th batch itself enters each loss.
        """
        global tot_loss
        #save prev hidden states
        states = [(None, init_state)]
        #iterate on batches
        for batch, (x0, y0, t, targets) in enumerate(data):
            #hand the batch input to func; forward() reads it back as self.x0
            self.func.get_new_observation(x0)
            #detach state from grad computation graph
            state = states[-1][1].detach()
            #run solver on the batch which will call func.forward() under the hood
            pred, new_state = odeint(self.func, (y0, state), t)
            #append the new_state (hidden state at the last evaluation point)
            states.append((state, new_state[-1].view(1, 1, -1)))
            #only the newest entry is ever read, so drop the older ones
            #instead of letting the list grow with the number of batches
            del states[:-1]

            if (batch + 1) % self.k == 0:
                loss = self.loss_fn(pred, targets)
                #without these three calls the parameters are never updated
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                #accumulate a plain number so no autograd graph is kept alive
                tot_loss = tot_loss + loss.item()

step 4

class NN_Module(nn.Module):
    """Right-hand side of the ODE: an RNN cell followed by a linear read-out.

    forward() is meant to be called by the ODE solver with the current
    (y, h) pair and returns the pair used as (dy/dt, dh/dt).
    """

    def __init__(self, hidden_size):
        """Create the layers.

        Args:
            hidden_size: number of features of the RNN hidden state.
        """
        #nn.Module must be initialised before sub-modules can be assigned
        super().__init__()
        #net layers
        self.rnn = nn.RNN(1, hidden_size, batch_first=True)
        self.dense = nn.Linear(hidden_size, 1)
        #input sample of the current batch, set through get_new_observation()
        self.x0 = None

    def get_new_observation(self, x0):
        """Store the input sample that forward() feeds to the RNN."""
        self.x0 = x0

    def forward(self, t, init_conditions):
        """Evaluate the derivative pair at the solver's current point.

        Args:
            t: evaluation time supplied by the solver (not used here).
            init_conditions: tuple (y0, h0); only h0 is used, as the
                initial hidden state of the RNN.

        Returns:
            tuple (y, hn): the linear read-out of the RNN output and the
            RNN's final hidden state.

        Raises:
            RuntimeError: if get_new_observation() has not been called yet.
        """
        if self.x0 is None:
            raise RuntimeError("call get_new_observation(x0) before forward()")
        #RNN update equations
        x0 = self.x0
        y0, h0 = init_conditions
        ht, hn = self.rnn(x0, h0)
        y = self.dense(ht)
        #f is a tuple(dy/dt, dh/dt) at t=T where T is wherever the solver is evaluating
        f = (y, hn)
        #the solver needs the derivatives back; without a return it receives None
        return f

step 5

#2 main steps
#1. make an instance of ODE_RNN_TBPTT: this is supposed to train a NN wrapped inside an odeint (ODE integrator) call
#2. call get_data() and feed the trainer

func = NN_Module(hidden_size).to(device)                    #func implements f(t,y(t),params) representing dy/dt as NN
loss_fn = nn.MSELoss()                                      #loss criterion
k = 1                                                       #k1,k2 are no. of batchs per gradient update; assume k1=k2 for now
optimizer = torch.optim.Adam(func.parameters(), lr=1e-3)
trainer = ODE_RNN_TBPTT(func, loss_fn, k, optimizer)

#NOTE(review): nothing is cleared here; an existing 'runs' directory is reused as-is
writer = SummaryWriter('runs')                              #create a logger

#test loop idx
ii= 0

#get_data() only slices the fixed training tensors, so the batches are built
#once instead of on every iteration
data, init_state, targets = get_data()                      #get training data

try:
    for itr in range(1, niters + 1):
        tot_loss = 0.0                                      #loss per itr
        trainer.train(data, init_state)                     #feed the trainer
        print("itr: {0} | loss: {1}".format(itr,tot_loss))
        writer.add_scalar('loss', tot_loss, itr)            #log to writer
finally:
    #flush pending events and release the log file even if training fails
    writer.close()

I was able to get training results, but the loss is usually huge and decreases only in very small steps!
I also tried a GRU instead of a vanilla RNN.
Update: The gradients are vanishing even for very small batches. I guess that’s because backpropagation through odeint takes into account all the steps the integrator takes.

itr: 1 | loss: 4.284154891967773
itr: 2 | loss: 4.283952236175537
itr: 3 | loss: 4.283847808837891
itr: 4 | loss: 4.283742904663086
itr: 5 | loss: 4.283634662628174
itr: 6 | loss: 4.283525466918945
itr: 7 | loss: 4.283415794372559


Are there any obvious methodological mistakes in the code, or is the problem with the training technique itself, so that I need to find another way to handle the data or change the network structure?
Any advice or recommendations are welcome.