I’m trying to predict a sine wave with an RNN. Although the output learns the trend of the curve, it appears to be on a different scale, as you can see in the following plot:
A couple of questions came to mind:
- Since the input already ranges from -1 to 1, is it necessary to normalise it?
- Changing the optimizer from SGD to Adam changed the scale of the output drastically, to roughly -4.5 to 4.5. I can’t understand why. (I wanted to post a second image showing this, but I can’t since I’m a new user.)
I think I’m missing a very basic machine learning concept, and some data science skills as well, so any help tackling this problem would be greatly appreciated.
Data
import numpy as np
import torch

wave = np.sin(np.arange(0, 8 * np.pi, 0.1))
wave = torch.from_numpy(wave).repeat(3).float()
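The code further down refers to `train` and `test`, which are not defined anywhere in the post. As a minimal sketch only, here is one way they might be produced from `wave`. The 80/20 ratio and the name `split` are assumptions for illustration, not the split actually used.

```python
import numpy as np
import torch

wave = np.sin(np.arange(0, 8 * np.pi, 0.1))
wave = torch.from_numpy(wave).repeat(3).float()

# Hypothetical 80/20 split; the real split is not given in the post.
# The training length is rounded down to a multiple of 4 so that
# train.view(train.shape[0] // 4, 4) is a valid reshape.
split = int(wave.shape[0] * 0.8) // 4 * 4
train, test = wave[:split], wave[split:]

print(train.shape, test.shape)
```

Whatever split is used, the length of `train` has to be divisible by 4 for the `view` call in the training loop to succeed.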
Model
import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(RNN, self).__init__()
        # Stored so that initHidden() can use it
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        # Maps each hidden state to a single predicted value
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, input, hidden):
        output, hidden = self.rnn(input, hidden)
        output = self.fc(output)
        return output, hidden

    def initHidden(self):
        # (num_layers, hidden_size) for an unbatched input
        return torch.zeros(1, self.hidden_size)
Hyperparameters
import torch.optim as optim

model = RNN(input_size=4, hidden_size=256)
learning_rate = 1e-2
epochs = 100
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
loss_fn = nn.MSELoss()
Training loop
min_loss = np.inf
model.train()
loss_train = 0.0
for epoch in range(0, epochs):
    # Start every epoch from a zero hidden state: (num_layers, hidden_size)
    hidden = torch.zeros(1, 256)
    optimizer.zero_grad()
    # Feed the training data as one sequence of 4-sample steps
    x, hidden = model(train.view(train.shape[0] // 4, 4), hidden)
    # Compare only the last step's output with the first test sample
    loss = loss_fn(x[-1], test[0].unsqueeze(0))
    loss.backward()
    optimizer.step()
    loss_train += loss.item()
    if epoch % 10 == 0:
        print('Epoch: {}/{}.............'.format(epoch, epochs), end=' ')
        print("Loss: {:.4f}".format(loss.item()))
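To make the shapes in the training step easier to follow, here is a small stand-alone sketch. It rebuilds the same layers outside the `RNN` class, and the 600-sample `train` tensor is a hypothetical stand-in, not the real data. It assumes a PyTorch version in which `nn.RNN` accepts unbatched 2-D input, as the loop above also does.

```python
import torch
from torch import nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=4, hidden_size=256, batch_first=True)
fc = nn.Linear(256, 1)

# Hypothetical stand-in for `train`: 600 samples of a sine wave.
train = torch.sin(torch.arange(600) * 0.1)

# 2-D input is treated as one unbatched sequence: (seq_len, input_size)
steps = train.view(train.shape[0] // 4, 4)
hidden = torch.zeros(1, 256)  # (num_layers, hidden_size)

output, hidden = rnn(steps, hidden)
pred = fc(output)

print(steps.shape)     # torch.Size([150, 4])
print(pred.shape)      # torch.Size([150, 1])
print(pred[-1].shape)  # torch.Size([1])
```

With these shapes, `x[-1]` in the loop is a single value, so each epoch's loss is computed against one target only (`test[0]`).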
Evaluation
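This is how I generate the output shown in the plot above.
src = train[-4:].unsqueeze(0)
results = []
hidden = torch.zeros(1, 256)
model.eval()
with torch.no_grad():
    for i in range(0, test.size()[0]):
        # Slide the window forward: drop the oldest i+1 training samples
        # and append the first i+1 test samples
        src = torch.cat((train[i+1::], test[:i+1])).clone()
        o, h = model(src.view(src.shape[0] // 4, 4), hidden)
        # Note: += extends the list with every row of o, not just the last one
        results += o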
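One detail of the evaluation loop worth checking is what `results += o` does when `o` is a tensor. The sketch below uses a hand-made 3-step tensor as a stand-in for a model output (the values are arbitrary, and `all_steps` and `last_only` are illustrative names). It compares `+=` with appending only the last time step, which may or may not be what was intended here.

```python
import torch

# Stand-in for one model output `o` with 3 time steps (values are arbitrary).
o = torch.tensor([[0.1], [0.2], [0.3]])

all_steps = []
all_steps += o                  # extends the list with one tensor per row of `o`

last_only = []
last_only.append(o[-1].item())  # keeps only the last time step as a Python float

print(len(all_steps))  # 3
print(len(last_only))  # 1
```

If only one prediction per window is wanted, the second form yields a list with one entry per window, so its length matches that of `test`.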
Thank you.