LSTM for predicting a sine wave

Hi, I am trying to create an LSTM model that predicts a sine wave. Eventually I want to implement this using only basic operations within the autograd system.
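
For reference, by "basic operations" I mean roughly a hand-written LSTM step like the sketch below (untested, just to illustrate the goal; the parameter names W, U, b are my own placeholders):

import torch

def lstm_step(x, h, c, W, U, b):
    # x: (batch, input_size); h, c: (batch, hidden_size)
    # W: (input_size, 4 * hidden_size), U: (hidden_size, 4 * hidden_size), b: (4 * hidden_size,)
    gates = x @ W + h @ U + b
    i, f, g, o = gates.chunk(4, dim=1)  # input, forget, cell, output pre-activations
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    c_next = f * c + i * torch.tanh(g)   # update the cell state
    h_next = o * torch.tanh(c_next)      # compute the new hidden state
    return h_next, c_next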

But for testing things out, I am using torch.nn.LSTM for now.

import numpy as np
import torch
import matplotlib.pyplot as plt

# Parameters
INPUT_LEN = 400
SEQ_LEN = 20
INPUT_SIZE = 1
NUM_LAYERS = 1
HIDDEN_SIZE = 1
BATCH_SIZE = 40

# training data, shaped (SEQ_LEN, BATCH_SIZE, INPUT_SIZE) as nn.LSTM expects
X = (torch.arange(2 * INPUT_LEN, dtype=torch.get_default_dtype()) / INPUT_LEN).reshape(SEQ_LEN, BATCH_SIZE, INPUT_SIZE)
y = torch.sin(2 * np.pi * X).reshape(SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE)

rnn = torch.nn.LSTM(INPUT_SIZE, HIDDEN_SIZE)

optim = torch.optim.Adam(rnn.parameters())

for epoch in range(20):
    # the first cell has neither a previous output nor a
    # previous hidden state, so set them to zero
    h0 = torch.zeros((NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE))
    c0 = torch.zeros((NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE))

    out, (h_out, c_out) = rnn(X, (h0, c0))

    optim.zero_grad()

    # Mean squared error loss
    loss = (y - out).pow(2).sum().div(SEQ_LEN * BATCH_SIZE)
    loss.backward()
    
    optim.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}")
    
N_STEPS = 20 # length of the seed window used for each one-step prediction

X = (torch.arange(2 * INPUT_LEN, dtype=torch.get_default_dtype()) / INPUT_LEN)
y = torch.cos(2 * np.pi * X) # test signal: a cosine, i.e. a phase-shifted sine
y_init = y[:N_STEPS]

# start with the initial values; the rest is predicted one step at a time
y_pred = y_init
with torch.no_grad():  # no gradients needed during prediction
    for i in range(len(X) - N_STEPS):
        # shape (seq_len, batch, input_size) = (N_STEPS, 1, 1)
        inp = y_pred[i:i + N_STEPS].reshape(N_STEPS, 1, INPUT_SIZE)
        out, (h_out, c_out) = rnn(inp)
        # h_out is the hidden state at t = seq_len, used as the one-step prediction
        y_pred = torch.cat((y_pred, h_out.flatten()))

# detach and convert results to numpy for plotting
X = X.detach().numpy()
y = y.detach().numpy()
y_pred = y_pred.detach().numpy()
y_init = y_init.detach().numpy()

# plot result
plt.plot(X[N_STEPS:], y_pred[N_STEPS:], label="prediction")
plt.plot(X, y, label="ground truth")
plt.plot(X[:N_STEPS], y_init, label="initial sequence", color="red")
plt.legend(loc="upper left")
plt.show()

This code does not seem to train the model at all; the loss stays around 0.6 across epochs. I suspect there might be something wrong with the loss function (MSE) I defined.
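
For what it's worth, I believe my expression is equivalent to the built-in MSE, since the summed error is divided by the total number of elements (SEQ_LEN * BATCH_SIZE * HIDDEN_SIZE, and HIDDEN_SIZE is 1), so this should behave the same:

import torch.nn.functional as F

# default reduction="mean" divides the summed squared error by the element count
loss = F.mse_loss(out, y)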

[plot: prediction vs. ground truth vs. initial sequence]

I want the prediction (blue line) to follow the ground truth (orange line).