Loss does not decrease (EDMD with dictionary learning)

Hello all,

I hope that you are well.

I am having trouble getting Algorithm 1 (see attached image) to converge. The algorithm comes from https://arxiv.org/pdf/1707.00225.pdf (extended dynamic mode decomposition with dictionary learning).

The algorithm alternates between two steps (a stripped-down sketch of how I implemented the loop is just below; the full code is at the end of the post):

  • Step 1: K is computed in closed form from the current dictionary (a regularized least-squares solve).
  • Step 2: with K held fixed, one Adam update is applied to the dictionary network.
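
In code, one pass of the loop looks roughly like this (a simplified sketch, not the exact implementation; Psi_x and Psi_y stand for the dictionary evaluated on the two snapshot sets, and the full listing is at the end of the post):

import torch

def two_step_pass(dictionary, optimizer, x_data, y_data, lam):
    """One pass of the two-step loop (sketch)."""
    N = x_data.shape[0]
    optimizer.zero_grad()
    Psi_x = dictionary(x_data)      # dictionary evaluated on the current states
    Psi_y = dictionary(y_data)      # dictionary evaluated on the next states

    # Step 1: closed-form, regularized least-squares solve for K
    G = Psi_x.T @ Psi_x / N
    A = Psi_x.T @ Psi_y / N
    K = torch.linalg.pinv(G + lam * torch.eye(G.shape[0])) @ A
    K = K.detach()                  # K is treated as a constant in step 2

    # Step 2: one Adam update of the dictionary, with K held fixed
    loss = torch.sum((Psi_y - Psi_x @ K) ** 2) + lam * torch.sum(K ** 2)
    loss.backward()
    optimizer.step()
    return loss.item()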

The loss is always large (note that it is summed over all N = 10 000 snapshot pairs) and, as the log below shows, it decreases for a while and then starts increasing again. I am not sure how to help the algorithm converge.

I was wondering if you have any ideas.

Loss values for some iterations (the bare number heading each block is the iteration count, printed every 100 iterations; the loss is printed every 10 iterations):

7800
loss tensor(5367.2446, grad_fn=<AddBackward0>)
loss tensor(5338.4434, grad_fn=<AddBackward0>)
loss tensor(5309.5991, grad_fn=<AddBackward0>)
loss tensor(5280.5703, grad_fn=<AddBackward0>)
loss tensor(5251.2197, grad_fn=<AddBackward0>)
loss tensor(5221.1494, grad_fn=<AddBackward0>)
loss tensor(5189.9116, grad_fn=<AddBackward0>)
loss tensor(5157.0610, grad_fn=<AddBackward0>)
loss tensor(5122.1064, grad_fn=<AddBackward0>)
loss tensor(5084.7554, grad_fn=<AddBackward0>)

8000
loss tensor(4591.7642, grad_fn=<AddBackward0>)
loss tensor(4557.1733, grad_fn=<AddBackward0>)
loss tensor(4527.3760, grad_fn=<AddBackward0>)
loss tensor(4502.6963, grad_fn=<AddBackward0>)
loss tensor(4483.3115, grad_fn=<AddBackward0>)
loss tensor(4469.2671, grad_fn=<AddBackward0>)
loss tensor(4460.4233, grad_fn=<AddBackward0>)
loss tensor(4456.5854, grad_fn=<AddBackward0>)
loss tensor(4457.3491, grad_fn=<AddBackward0>)
loss tensor(4462.3418, grad_fn=<AddBackward0>)
8100
loss tensor(4471.1421, grad_fn=<AddBackward0>)
loss tensor(4483.3481, grad_fn=<AddBackward0>)
loss tensor(4498.5205, grad_fn=<AddBackward0>)
loss tensor(4516.3730, grad_fn=<AddBackward0>)
loss tensor(4536.7710, grad_fn=<AddBackward0>)
loss tensor(4559.5410, grad_fn=<AddBackward0>)
loss tensor(4584.7847, grad_fn=<AddBackward0>)
loss tensor(4612.6025, grad_fn=<AddBackward0>)
loss tensor(4643.2603, grad_fn=<AddBackward0>)
loss tensor(4677.0620, grad_fn=<AddBackward0>)
8200
loss tensor(4714.4238, grad_fn=<AddBackward0>)
loss tensor(4755.7139, grad_fn=<AddBackward0>)
loss tensor(4801.4082, grad_fn=<AddBackward0>)
loss tensor(4852.0024, grad_fn=<AddBackward0>)
loss tensor(4907.9810, grad_fn=<AddBackward0>)
loss tensor(4969.7529, grad_fn=<AddBackward0>)
loss tensor(5037.8047, grad_fn=<AddBackward0>)
loss tensor(5112.5469, grad_fn=<AddBackward0>)
loss tensor(5194.3159, grad_fn=<AddBackward0>)
loss tensor(5283.3013, grad_fn=<AddBackward0>)
8300
loss tensor(5379.5381, grad_fn=<AddBackward0>)
loss tensor(5482.8652, grad_fn=<AddBackward0>)
loss tensor(5592.7178, grad_fn=<AddBackward0>)
loss tensor(5708.4941, grad_fn=<AddBackward0>)
loss tensor(5829.6284, grad_fn=<AddBackward0>)
loss tensor(5955.6792, grad_fn=<AddBackward0>)
loss tensor(6087.0273, grad_fn=<AddBackward0>)
loss tensor(6225.1846, grad_fn=<AddBackward0>)
loss tensor(6372.8101, grad_fn=<AddBackward0>)
loss tensor(6532.9771, grad_fn=<AddBackward0>)
8400
loss tensor(6708.2900, grad_fn=<AddBackward0>)
loss tensor(6899.4150, grad_fn=<AddBackward0>)
loss tensor(7104.4360, grad_fn=<AddBackward0>)
loss tensor(7319.9741, grad_fn=<AddBackward0>)
loss tensor(7541.7173, grad_fn=<AddBackward0>)
loss tensor(7766.3032, grad_fn=<AddBackward0>)
loss tensor(7992.1851, grad_fn=<AddBackward0>)
loss tensor(8219.3711, grad_fn=<AddBackward0>)
loss tensor(8448.2607, grad_fn=<AddBackward0>)
loss tensor(8678.2930, grad_fn=<AddBackward0>)
8500
loss tensor(8907.6895, grad_fn=<AddBackward0>)
loss tensor(9133.2422, grad_fn=<AddBackward0>)
loss tensor(9350.0273, grad_fn=<AddBackward0>)
loss tensor(9552.0527, grad_fn=<AddBackward0>)
loss tensor(9733.2754, grad_fn=<AddBackward0>)
loss tensor(9888.7441, grad_fn=<AddBackward0>)
loss tensor(10016.0166, grad_fn=<AddBackward0>)
loss tensor(10115.1289, grad_fn=<AddBackward0>)
loss tensor(10189.4238, grad_fn=<AddBackward0>)
loss tensor(10242.7695, grad_fn=<AddBackward0>)
8600
loss tensor(10280.3711, grad_fn=<AddBackward0>)
loss tensor(10306.0215, grad_fn=<AddBackward0>)
loss tensor(10323.3418, grad_fn=<AddBackward0>)
loss tensor(10334.4590, grad_fn=<AddBackward0>)
loss tensor(10340.8926, grad_fn=<AddBackward0>)
loss tensor(10344.0332, grad_fn=<AddBackward0>)
loss tensor(10345.1807, grad_fn=<AddBackward0>)
loss tensor(10346.1221, grad_fn=<AddBackward0>)
loss tensor(10348.2168, grad_fn=<AddBackward0>)
loss tensor(10353.4258, grad_fn=<AddBackward0>)
8700
loss tensor(10362.4160, grad_fn=<AddBackward0>)
loss tensor(10375.5801, grad_fn=<AddBackward0>)
loss tensor(10392.3262, grad_fn=<AddBackward0>)
loss tensor(10411.0020, grad_fn=<AddBackward0>)
loss tensor(10429.6211, grad_fn=<AddBackward0>)
loss tensor(10445.5811, grad_fn=<AddBackward0>)
loss tensor(10455.9043, grad_fn=<AddBackward0>)
loss tensor(10458.1855, grad_fn=<AddBackward0>)
loss tensor(10449.9951, grad_fn=<AddBackward0>)
loss tensor(10430.0352, grad_fn=<AddBackward0>)
8800
loss tensor(10397.1309, grad_fn=<AddBackward0>)
loss tensor(10351.6621, grad_fn=<AddBackward0>)
loss tensor(10294.1875, grad_fn=<AddBackward0>)
loss tensor(10226.2588, grad_fn=<AddBackward0>)
loss tensor(10149.6143, grad_fn=<AddBackward0>)
loss tensor(10066.1826, grad_fn=<AddBackward0>)
loss tensor(9977.8926, grad_fn=<AddBackward0>)
loss tensor(9886.4258, grad_fn=<AddBackward0>)
loss tensor(9792.8936, grad_fn=<AddBackward0>)
loss tensor(9698.7217, grad_fn=<AddBackward0>)

The code


#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 28 10:08:00 2022

@author: localadmin
"""

import os
import argparse
import time
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import cloudpickle

from scipy import linalg as la

from matplotlib import pyplot as plt
from scipy.stats import uniform
from statistics import mean


data_name = 'Duffing_oscillator'  # other options: 'spectrum', 'Linear', 'Discrete_Linear', 'spectrum-1'


lambda_ = 1e-2   # Tikhonov regularization weight for K
epsilon = 30     # loss threshold below which training stops

d = 2            # state dimension
l = 170          # hidden-layer width
M = 22           # number of trainable dictionary functions
I = torch.eye(M + 3, M + 3)   # identity for the regularized pseudo-inverse (M net outputs + d state coordinates + 1 constant)

N = 10000        # number of snapshot pairs
inv_N = 1 / N    # = 1e-4

net = nn.Sequential(
    nn.Linear(d, l),
    nn.Tanh(),
    nn.Linear(l, l),
    nn.Tanh(),
    nn.Linear(l, l),
    nn.Tanh(),
    nn.Linear(l, M),
)
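# net is the trainable dictionary: it maps the d-dimensional state to M learned observables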


 
optimizer = torch.optim.Adam(net.parameters(), lr=1e-5)
loss_fn = nn.MSELoss()  # defined but not used below; the loss is assembled manually

def data_Preprocessing(tr_val_te, cut):
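    """Load the first `cut` rows of a CSV data set and return it as a float32 tensor."""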
    data = np.loadtxt(('./data/%s_%s.csv' % (data_name, tr_val_te)), delimiter=',', dtype=np.float64)[:cut]
    data = torch.tensor(data, dtype=torch.float32)
    return data


def Frobenius_norm(X):
    # squared Frobenius norm: trace(X X^T) = sum of squared entries
    XXt = torch.mm(X, torch.transpose(X, 0, 1))
    return torch.sum(torch.diag(XXt, 0))


    
loss_history = []   # loss value per iteration


# Training data: snapshot pairs (x_data = current states, y_data = the states one time step later)
x_data = data_Preprocessing("train_x", N)
y_data = data_Preprocessing("train_y", N)



count = 0
rotation = 50000   # maximum number of iterations

loss = float("inf")



while loss > epsilon  and count < rotation:
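    # One pass of Algorithm 1: step 1 computes K_tilde in closed form for the
    # current dictionary; step 2 applies one Adam update to the net with K_tilde fixed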
    if count % 100 == 0:
        print(count)
        
    optimizer.zero_grad(set_to_none=False)


    pred_sai = net(x_data)  
    y_pred_sai = net(y_data)
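    # pred_sai / y_pred_sai hold the trainable dictionary outputs, shape (N, M);
    # the non-trainable columns are appended below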

  

    # Non-trainable part of the dictionary: the state itself plus a constant 0.1,
    # i.e. 3 extra columns -> torch.Size([10000, 3])
    # (x_data and y_data never change, so this could also be built once before the loop)
    fixed_sai = torch.cat([x_data, 0.1 * torch.ones(x_data.shape[0], 1)], dim=1)
    y_fixed_sai = torch.cat([y_data, 0.1 * torch.ones(y_data.shape[0], 1)], dim=1)
    


    
    # Full dictionary = trainable part + non-trainable part, shape (N, M + 3)
    pred_sai = torch.cat([pred_sai, fixed_sai], dim=1)
    y_pred_sai = torch.cat([y_pred_sai, y_fixed_sai], dim=1)
    



    pred_sai_T = torch.transpose(pred_sai, 0, 1)

    # Step 1: closed-form, regularized least-squares solve for K_tilde
    G = inv_N * torch.matmul(pred_sai_T, pred_sai)    # G = (1/N) Psi_X^T Psi_X
    A = inv_N * torch.matmul(pred_sai_T, y_pred_sai)  # A = (1/N) Psi_X^T Psi_Y

    K_tilde = torch.mm(torch.linalg.pinv(G + lambda_ * I), A)

    # Detach K_tilde so no gradient flows through step 1; it is a constant in step 2
    K_tilde = K_tilde.detach()

    # One-step prediction in dictionary space: since K_tilde = (G + lambda*I)^+ A,
    # pred_sai @ K_tilde approximates y_pred_sai
    Pred = torch.mm(pred_sai, K_tilde)

    # Loss for the algorithm: squared residual summed over all N snapshot pairs
    # plus the Frobenius regularizer on K_tilde
    # (K_tilde is detached, so the regularizer only shifts the loss value)
    res = lambda_ * Frobenius_norm(K_tilde)
    MSE = (y_pred_sai - Pred) ** 2

    loss = torch.sum(MSE) + res
    
    
        
    loss_history.append(loss.item())  # store a plain float so the autograd graph is not kept alive
    if count % 10 == 0:
        print("loss", loss)
        
    loss.backward()
    # torch.nn.utils.clip_grad_norm(net.parameters(), args.clip)
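    # Step 2: one Adam update of the dictionary network (K_tilde is held fixed
    # because it was detached above)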

    optimizer.step()

    count += 1
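
# Not part of Algorithm 1: once the loop finishes, plot the recorded loss history
# to see how it evolves over the iterations (matplotlib is imported above as plt)
plt.figure()
plt.plot(loss_history)
plt.xlabel("iteration")
plt.ylabel("loss")
plt.show()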