# LSTM for sequence prediction

Hey !

I’m trying to use LSTM module to predict a rather simple sequence. Basically, the network receives 20 time steps, and I want to predict steps 1 to 21.
I thought my model was fine, but I can’t get the loss to decrease. Could someone help ?

Here’s my code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt
plt.style.use('dark_background')

x = np.linspace(0,30.,500)
y = x*np.sin(x) + 2*np.sin(5*x)

nb_steps = 20

class LSTM(nn.Module):

def __init__(self):

nn.Module.__init__(self)

self.lstm = nn.LSTM(1,100)

def forward(self,x):

outputs, states = self.lstm(x)
outputs = outputs.reshape(x.shape[0]*x.shape[1], -1)

return pred

def load_batch(batch_size = 32):

x_b = np.zeros((nb_steps,batch_size,1))
y_b = np.zeros((nb_steps*batch_size,1))

inds = np.random.randint(0, 479, (batch_size))
for i,ind in enumerate(inds):
x_b[:,i,0] = y[ind:ind+nb_steps]
y_b[i*nb_steps:(i+1)*nb_steps,0] = y[ind+1:ind+nb_steps+1]

rnn = LSTM()

epochs = 1000
batch_size = 32

mean_loss = 0.
for epoch in range(1,epochs+1):

pred = rnn(x_b)
shaped_pred = pred.reshape(-1,1)
loss = F.mse_loss(shaped_pred, y_b)

loss.backward()

mean_loss += loss.item()

if epoch%100 == 0:
print('Epoch: {} | Loss: {:.6f}'.format(epoch, mean_loss/100.))
mean_loss = 0.

f, ax = plt.subplots(2,1)

while True :

x_b, y_b = load_batch(1)
pred = rnn(x_b).detach().numpy().reshape(-1)

ax[0].plot(x,y, label= 'Real')
ax[0].plot(x_b.numpy().reshape(-1),y_b.numpy().reshape(-1), label= 'Real batch')
ax[0].plot(x_b.numpy().reshape(-1), pred, label = 'Pred')

ax[1].scatter(x_b.numpy().reshape(-1),y_b.numpy().reshape(-1), label= 'Real')
ax[1].scatter(x_b.numpy().reshape(-1), pred, label = 'Pred')

for a in ax: a.legend()

plt.pause(0.1)
input()

for a in ax:
a.clear()

sorry for the bad format, I am still learning this site.

I had posted some code previous and realized it was wrong. The dangers of monkey patching. I figured it out. While this isn’t perfect, I did get the loss from 14000 to 5000 on my computer with this code. I am not sure it is right though. It doesn’t make my computer work hard enough for 1000 epochs. Maybe someone else can help with an explanation.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('dark_background')

x = np.linspace(0,30.,500)
y = x*np.sin(x) + 2*np.sin(5*x)

nb_steps = 20

class LSTM(nn.Module):

def __init__(self):

nn.Module.__init__(self)

self.lstm = nn.LSTM(1,100)

def forward(self,x):

outputs, states = self.lstm(x)
outputs = outputs.reshape(x.shape[0]*x.shape[1], -1)

return pred

def load_batch(batch_size = 32):

x_b = np.zeros((nb_steps,batch_size,1))
y_b = np.zeros((nb_steps*batch_size,1))

inds = np.random.randint(0, 479, (batch_size))
for i,ind in enumerate(inds):
x_b[:,i,0] = y[ind:ind+nb_steps]
y_b[i*nb_steps:(i+1)*nb_steps,0] = y[ind+1:ind+nb_steps+1]

rnn = LSTM()

epochs = 1000
batch_size = 32
criterion = nn.MSELoss()
mean_loss = 0.
for epoch in range(1,epochs+1):

#    pred = rnn(x_b)

#    loss = F.mse_loss(abs(shaped_pred), abs(y_b))
#    print(loss)
#    loss.backward()

def closure():
global loss
pred = rnn(x_b)
shaped_pred = pred.reshape(-1,1)
loss = criterion(abs(shaped_pred), abs(y_b))
#        print('loss:', loss.item())
loss.backward()

return loss

optimizer = optim.Adam(rnn.parameters(), 1e-3)
optimizer.step(closure)
mean_loss += loss.item()

if epoch%100 == 0:
print('Epoch: {} | Loss: {:.6f}'.format(epoch, mean_loss))
mean_loss = 0

f, ax = plt.subplots(2,1)

while True :

x_b, y_b = load_batch(1)
pred = rnn(x_b).detach().numpy().reshape(-1)

ax[0].plot(x,y, label= 'Real')
ax[0].plot(x_b.numpy().reshape(-1),y_b.numpy().reshape(-1), label= 'Real batch')
ax[0].plot(x_b.numpy().reshape(-1), pred, label = 'Pred')

ax[1].scatter(x_b.numpy().reshape(-1),y_b.numpy().reshape(-1), label= 'Real')
ax[1].scatter(x_b.numpy().reshape(-1), pred, label = 'Pred')

for a in ax: a.legend()

plt.pause(0.1)
input()

for a in ax:
a.clear()

I was able to get the losses down by using a batch size of 1 and running for 5000 epochs. I think now it is just an issue of tuning.