I’m trying to use the nn.LSTM module to predict a rather simple sequence. Basically, the network receives 20 time steps (steps 1 to 20), and I want it to predict the same window shifted by one (steps 2 to 21).
I thought my model was fine, but I can’t get the loss to decrease. Could someone help?
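In case it matters, I’m using the default batch_first=False layout, so the LSTM input is (seq_len, batch, input_size); here a batch is (20, 32, 1). A quick sanity check of the shapes I expect (illustrative snippet, separate from the script below):

import torch
import torch.nn as nn

lstm = nn.LSTM(1, 100)            # input_size=1, hidden_size=100
dummy = torch.randn(20, 32, 1)    # (seq_len=20, batch=32, features=1)
out, (h, c) = lstm(dummy)
print(out.shape)                  # torch.Size([20, 32, 100]) -- time-major output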
Here’s my code:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('dark_background')
x = np.linspace(0,30.,500)
y = x*np.sin(x) + 2*np.sin(5*x)
nb_steps = 20
class LSTM(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.lstm = nn.LSTM(1, 100)    # input_size=1, hidden_size=100
        self.head = nn.Linear(100, 1)

    def forward(self, x):
        outputs, states = self.lstm(x)                            # (seq_len, batch, 100)
        outputs = outputs.reshape(x.shape[0]*x.shape[1], -1)      # time-major flatten: row t*batch + b
        pred = self.head(outputs)
        return pred
def load_batch(batch_size=32):
    x_b = np.zeros((nb_steps, batch_size, 1))
    y_b = np.zeros((nb_steps*batch_size, 1))
    inds = np.random.randint(0, 479, batch_size)
    for i, ind in enumerate(inds):
        x_b[:, i, 0] = y[ind:ind+nb_steps]                            # inputs: steps t .. t+19
        y_b[i*nb_steps:(i+1)*nb_steps, 0] = y[ind+1:ind+nb_steps+1]   # targets: steps t+1 .. t+20, batch-major
    return torch.tensor(x_b).float(), torch.tensor(y_b).float()
rnn = LSTM()
adam = optim.Adam(rnn.parameters(), 1e-3)
epochs = 1000
batch_size = 32
mean_loss = 0.
for epoch in range(1, epochs+1):
    x_b, y_b = load_batch(batch_size)
    pred = rnn(x_b)
    shaped_pred = pred.reshape(-1, 1)
    loss = F.mse_loss(shaped_pred, y_b)

    adam.zero_grad()
    loss.backward()
    adam.step()

    mean_loss += loss.item()
    if epoch % 100 == 0:
        print('Epoch: {} | Loss: {:.6f}'.format(epoch, mean_loss/100.))
        mean_loss = 0.
f, ax = plt.subplots(2,1)
while True:
    x_b, y_b = load_batch(1)
    pred = rnn(x_b).detach().numpy().reshape(-1)

    ax[0].plot(x, y, label='Real')
    ax[0].plot(x_b.numpy().reshape(-1), y_b.numpy().reshape(-1), label='Real batch')
    ax[0].plot(x_b.numpy().reshape(-1), pred, label='Pred')
    ax[1].scatter(x_b.numpy().reshape(-1), y_b.numpy().reshape(-1), label='Real')
    ax[1].scatter(x_b.numpy().reshape(-1), pred, label='Pred')
    for a in ax:
        a.legend()
    plt.pause(0.1)
    input()
    for a in ax:
        a.clear()
I had posted some code previously and realized it was wrong; the dangers of monkey patching. I figured it out. While this isn’t perfect, it did bring the printed loss from about 14000 down to about 5000 on my machine. I am not sure it is right, though. It doesn’t make my computer work very hard for 1000 epochs. (Adam doesn’t strictly need a closure; passing one to optimizer.step() is mainly for optimizers like LBFGS that re-evaluate the loss, but it works either way.) Maybe someone else can help with an explanation.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('dark_background')
x = np.linspace(0,30.,500)
y = x*np.sin(x) + 2*np.sin(5*x)
nb_steps = 20
class LSTM(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.lstm = nn.LSTM(1, 100)
        self.head = nn.Linear(100, 1)

    def forward(self, x):
        outputs, states = self.lstm(x)
        outputs = outputs.reshape(x.shape[0]*x.shape[1], -1)
        pred = self.head(outputs)
        return pred
def load_batch(batch_size=32):
    x_b = np.zeros((nb_steps, batch_size, 1))
    y_b = np.zeros((nb_steps*batch_size, 1))
    inds = np.random.randint(0, 479, batch_size)
    for i, ind in enumerate(inds):
        x_b[:, i, 0] = y[ind:ind+nb_steps]
        y_b[i*nb_steps:(i+1)*nb_steps, 0] = y[ind+1:ind+nb_steps+1]
    return torch.tensor(x_b).float(), torch.tensor(y_b).float()
rnn = LSTM()
epochs = 1000
batch_size = 32
criterion = nn.MSELoss()
mean_loss = 0.
for epoch in range(1, epochs+1):
    x_b, y_b = load_batch(batch_size)
    # pred = rnn(x_b)
    # loss = F.mse_loss(abs(shaped_pred), abs(y_b))
    # print(loss)
    # loss.backward()
    def closure():
        global loss
        optimizer.zero_grad()
        pred = rnn(x_b)
        shaped_pred = pred.reshape(-1, 1)
        loss = criterion(abs(shaped_pred), abs(y_b))   # MSE on absolute values
        # print('loss:', loss.item())
        loss.backward()
        return loss
    optimizer = optim.Adam(rnn.parameters(), 1e-3)     # note: re-created every epoch
    optimizer.step(closure)

    mean_loss += loss.item()
    if epoch % 100 == 0:
        print('Epoch: {} | Loss: {:.6f}'.format(epoch, mean_loss))   # running sum over 100 epochs
        mean_loss = 0
f, ax = plt.subplots(2,1)
while True:
    x_b, y_b = load_batch(1)
    pred = rnn(x_b).detach().numpy().reshape(-1)

    ax[0].plot(x, y, label='Real')
    ax[0].plot(x_b.numpy().reshape(-1), y_b.numpy().reshape(-1), label='Real batch')
    ax[0].plot(x_b.numpy().reshape(-1), pred, label='Pred')
    ax[1].scatter(x_b.numpy().reshape(-1), y_b.numpy().reshape(-1), label='Real')
    ax[1].scatter(x_b.numpy().reshape(-1), pred, label='Pred')
    for a in ax:
        a.legend()
    plt.pause(0.1)
    input()
    for a in ax:
        a.clear()
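A possible explanation, for anyone reading along: in both scripts the flattened prediction and the target are in different orders. forward() flattens the LSTM output time-major (row t*batch + b), while load_batch fills y_b batch-major (sample i occupies rows i*nb_steps to (i+1)*nb_steps), so the MSE pairs predictions with the wrong targets. On top of that, abs() in the loss lets the sign of the prediction drift, and re-creating the Adam optimizer inside the loop discards its momentum state every step. Below is an untested sketch of how the pieces could be lined up, reusing the LSTM class, the x/y data, and the hyperparameters defined above; treat it as a suggestion rather than a verified fix.

# Untested sketch: keep targets in the same time-major (seq_len, batch, 1)
# layout as the inputs so predictions and targets line up one-to-one.
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

def load_batch(batch_size=32):
    x_b = np.zeros((nb_steps, batch_size, 1))
    y_b = np.zeros((nb_steps, batch_size, 1))            # same layout as x_b
    inds = np.random.randint(0, len(y) - nb_steps - 1, batch_size)
    for i, ind in enumerate(inds):
        x_b[:, i, 0] = y[ind:ind + nb_steps]             # inputs: steps t .. t+19
        y_b[:, i, 0] = y[ind + 1:ind + nb_steps + 1]     # targets: steps t+1 .. t+20
    return torch.tensor(x_b).float(), torch.tensor(y_b).float()

rnn = LSTM()
optimizer = optim.Adam(rnn.parameters(), 1e-3)           # created once, outside the loop
criterion = nn.MSELoss()

for epoch in range(1, epochs + 1):
    x_b, y_b = load_batch(batch_size)
    pred = rnn(x_b).reshape(nb_steps, batch_size, 1)     # undo the time-major flatten
    loss = criterion(pred, y_b)                          # plain MSE, no abs()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()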