I’m trying to learn how to use PyTorch and wanted to start with a really simple test.

So I’ve created the function `f(x) = np.sin(x * 50) * 500 + x * 2000`, and I want a feedforward neural network to approximate it on the range [0, 1].
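To see how large the targets are compared to the inputs, here is a small sketch that samples the function on the same grid. It reports the range of `y` and how many sine periods fit into [0, 1]. It uses only NumPy, and `periods` is a name introduced just for this illustration.

```python
import numpy as np

def func(x):
    return np.sin(x * 50) * 500 + x * 2000

# Same sampling grid as in the code below
x_steps = np.linspace(0, 1, 5000)
y_steps = func(x_steps)

# sin(50 * x) has period 2*pi/50, so [0, 1] holds 50 / (2*pi) periods
periods = 50 / (2 * np.pi)

print(f"x range: {x_steps.min():.1f} .. {x_steps.max():.1f}")
print(f"y range: {y_steps.min():.1f} .. {y_steps.max():.1f}")
print(f"sine periods in [0, 1]: {periods:.2f}")
```

The inputs stay in [0, 1], but the targets run from roughly -300 to above 2300, with about eight full oscillations on top of the linear trend.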

I’ve created 5000 equally spaced samples and trained the network on all of them. I don’t care if it overfits; it’s just for testing.
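For reference, the same 5000 equally spaced samples can also be wrapped in PyTorch's `TensorDataset` and `DataLoader`. This is a swapped-in alternative to the manual index batching in the code below, not what the code currently does. One difference is that `shuffle=True` draws a new order at the start of every epoch, whereas the manual version shuffles only once. The batch size of 64 matches the original.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

def func(x):
    return np.sin(x * 50) * 500 + x * 2000

# 5000 equally spaced samples on [0, 1], stored as float32 column vectors
x_steps = np.linspace(0, 1, 5000)
x_tensor = torch.tensor(x_steps, dtype=torch.float32).reshape(-1, 1)
y_tensor = torch.tensor(func(x_steps), dtype=torch.float32).reshape(-1, 1)

dataset = TensorDataset(x_tensor, y_tensor)
# shuffle=True reshuffles the sample order at the start of every epoch
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Peek at one batch of inputs and targets
xb, yb = next(iter(loader))
print(len(loader), xb.shape, yb.shape)
```

In a training loop, `for xb, yb in loader:` would then replace the manual `x_steps[batch]` / `y_steps[batch]` indexing.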

From my tests, it seems like my network only learns a linear function or the mean of my samples.
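One way to check this suspicion is to compute the MSE of the two trivial predictors: the constant mean and the best straight line. The sketch below does that with NumPy's `polyfit`. A training loss that stalls near one of these values suggests the network has not learned anything beyond that baseline. The variable names are illustrative.

```python
import numpy as np

def func(x):
    return np.sin(x * 50) * 500 + x * 2000

x_steps = np.linspace(0, 1, 5000)
y_steps = func(x_steps)

# Baseline 1: always predict the mean of the targets
mean_mse = np.mean((y_steps - y_steps.mean()) ** 2)

# Baseline 2: the least-squares straight line (polynomial of degree 1)
slope, intercept = np.polyfit(x_steps, y_steps, 1)
lin_mse = np.mean((y_steps - (slope * x_steps + intercept)) ** 2)

print(f"MSE of mean prediction: {mean_mse:.0f}")
print(f"MSE of linear fit:      {lin_mse:.0f}")
```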

I’ve tried varying the number of layers, but even a network with 3 layers and around 100 neurons wasn’t able to find a good solution.
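To try different depths and widths without rewriting the `nn.Sequential` each time, the network can be built from a list of hidden-layer sizes. `make_mlp` below is a hypothetical helper, not a PyTorch function. `make_mlp([64, 32])` reproduces the architecture in the code below, and `make_mlp([100, 100, 100])` is one reading of "3 layers with around 100 neurons".

```python
import torch
import torch.nn as nn

def make_mlp(hidden_sizes, in_features=1, out_features=1):
    """Hypothetical helper: a ReLU MLP with the given hidden-layer widths."""
    layers = []
    prev = in_features
    for width in hidden_sizes:
        layers.append(nn.Linear(prev, width))
        layers.append(nn.ReLU())
        prev = width
    # Final linear layer without activation, for regression output
    layers.append(nn.Linear(prev, out_features))
    return nn.Sequential(*layers)

# Same architecture as in the question
small = make_mlp([64, 32])
# Three hidden layers of 100 neurons each
large = make_mlp([100, 100, 100])

out = large(torch.zeros(10, 1))
print(small)
print(out.shape)
```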

So I’m pretty sure I’m using PyTorch incorrectly, but I can’t find my error.

Below you can find my source code. I think the most important part is the two `for` loops, which contain the training procedure.

## Code

```
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# the target function
def func(x):
    return np.sin(x * 50) * 500 + x * 2000

# plot the target function
x_steps = np.linspace(0, 1, 5000)
y_steps = func(x_steps)
plt.plot(x_steps, y_steps)

# Create the model, optimizer and the loss function
model = nn.Sequential(
    nn.Linear(1, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 1))
optimizer = optim.Adam(model.parameters(), lr=0.1)
loss_func = nn.MSELoss()

# Create random index "batches" of size 64
# (np.arange replaces np.array(..., dtype=np.int); np.int was removed in NumPy 1.24)
indexes = np.arange(len(x_steps))
np.random.shuffle(indexes)
batches = torch.split(torch.from_numpy(indexes), 64)

# Training
for epoch in range(100):
    for batch in batches:
        # index the NumPy arrays with a NumPy index array
        x_batch = x_steps[batch.numpy()]
        y_batch = y_steps[batch.numpy()]
        # torch.Tensor(...) converts the float64 NumPy data to float32
        prediction = model(torch.Tensor(x_batch.reshape((-1, 1))))
        loss = loss_func(prediction, torch.Tensor(y_batch.reshape((-1, 1))))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # loss of the last batch in this epoch
    print(epoch, loss.item())

# Test if the function was approximated
with torch.no_grad():
    predictions = model(torch.Tensor(np.reshape(x_steps, (-1, 1))))
plt.plot(x_steps, y_steps)
plt.show()
plt.plot(x_steps, predictions.numpy())
plt.show()
```
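A common way to check the mechanics of a training loop like the one above is to run it on a target that is trivially easy to fit and confirm that the loss collapses. The sketch below keeps the same architecture and the same `zero_grad` / `backward` / `step` order. The straight line `y = 2x + 1`, the learning rate of 0.01 and the 500 full-batch steps are arbitrary assumptions made for this check, not values from the question.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Easy target: a straight line on the same input range [0, 1]
x = torch.linspace(0, 1, 200).reshape(-1, 1)
y = 2 * x + 1

# Same architecture as in the question
model = nn.Sequential(
    nn.Linear(1, 64), nn.ReLU(),
    nn.Linear(64, 32), nn.ReLU(),
    nn.Linear(32, 1))
optimizer = optim.Adam(model.parameters(), lr=0.01)  # assumed learning rate
loss_func = nn.MSELoss()

losses = []
for step in range(500):
    # Full-batch step: forward, loss, reset grads, backprop, update
    prediction = model(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.6f}")
```

If the loss drops close to zero here but not on `func`, the loop itself is probably fine, and the difference lies in the target (its scale and oscillations) or the hyperparameters.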