I’m trying to learn how to use PyTorch and wanted to start with a really simple test.
So I’ve created the function f(x) = np.sin(x * 50) * 500 + x * 2000,
and I want a feedforward neural network to approximate it on the range [0, 1].
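A quick way to see what the network is being asked to produce is to evaluate the target on the same 5000 points and print a few statistics. This is a minimal NumPy sketch; the choice of statistics (min, max, mean, standard deviation and the number of sine periods on [0, 1]) is my own addition and not part of the original training code.

```python
import numpy as np

# Target function from the post
def func(x):
    return np.sin(x * 50) * 500 + x * 2000

x = np.linspace(0, 1, 5000)
y = func(x)

# The outputs span a range of a few thousand units, and the sine term
# (amplitude 500) completes 50 / (2*pi), i.e. almost 8, periods on [0, 1].
periods = 50 / (2 * np.pi)
print(f"min={y.min():.1f}, max={y.max():.1f}, mean={y.mean():.1f}, std={y.std():.1f}")
print(f"sine periods on [0, 1]: {periods:.2f}")
```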
I’ve created 5000 equally spaced samples and trained the network on all of them. I don’t care if it overfits; it’s just for testing.
From my tests, it seems like my network only approximates a linear function or the mean of my samples…
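One way to check whether the network has really collapsed to "a linear function or the mean" is to compare its final training loss with the MSE of those two trivial predictors. A minimal NumPy sketch, assuming the same 5000 equally spaced samples; `np.polyfit(x, y, 1)` gives the least-squares straight line. Since the loop prints the loss of the last mini-batch only, the comparison is rough.

```python
import numpy as np

def func(x):
    return np.sin(x * 50) * 500 + x * 2000

x = np.linspace(0, 1, 5000)
y = func(x)

# Baseline 1: always predict the mean of the targets
mse_mean = np.mean((y - y.mean()) ** 2)

# Baseline 2: best straight line (least-squares fit of degree 1)
slope, intercept = np.polyfit(x, y, 1)
mse_linear = np.mean((y - (slope * x + intercept)) ** 2)

# If the network's loss ends up close to one of these values, it has most
# likely learned (approximately) that trivial solution.
print(f"MSE of mean predictor:   {mse_mean:.0f}")
print(f"MSE of linear predictor: {mse_linear:.0f}")
```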
I’ve tried varying the number of layers, but even a network with 3 layers with around 100 neurons wasn’t able to find a good solution.
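To experiment with depth and width without rewriting the `nn.Sequential` by hand each time, the layer list can be generated from a list of hidden sizes. `make_mlp` is a hypothetical helper for illustration, and `[100, 100, 100]` is only one reading of "3 layers with around 100 neurons"; `make_mlp([64, 32])` rebuilds the 64/32 network from the source code in this post.

```python
import torch
import torch.nn as nn

def make_mlp(hidden_sizes, in_features=1, out_features=1):
    """Hypothetical helper: Linear + ReLU per hidden size, then a linear output layer."""
    layers = []
    prev = in_features
    for width in hidden_sizes:
        layers.append(nn.Linear(prev, width))
        layers.append(nn.ReLU())
        prev = width
    layers.append(nn.Linear(prev, out_features))  # no activation on the output
    return nn.Sequential(*layers)

# Three hidden layers of 100 units each (one interpretation of the setup above)
model = make_mlp([100, 100, 100])
n_params = sum(p.numel() for p in model.parameters())
print(model)
print("trainable parameters:", n_params)
```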
So I’m pretty sure I’m using PyTorch incorrectly, but I cannot find my error.
Below is my source code; I think the most important parts are the two for loops, which contain the training procedure.
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# the target function
def func(x):
    return np.sin(x * 50) * 500 + x * 2000

# plot the target function
x_steps = np.linspace(0, 1, 5000)
y_steps = func(x_steps)
plt.plot(x_steps, y_steps)

# Create the model, optimizer and the loss function
model = nn.Sequential(
    nn.Linear(1, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 1))
optimizer = optim.Adam(model.parameters(), 0.1)
loss_func = nn.MSELoss()

# Create random index "batches" (np.int no longer exists in NumPy >= 1.24)
indexes = np.arange(len(x_steps), dtype=np.int64)
np.random.shuffle(indexes)
batches = torch.split(torch.from_numpy(indexes), 64)

# Training
for epoch in range(100):
    for batch in batches:
        # index the NumPy arrays with a NumPy index array
        x_batch = x_steps[batch.numpy()]
        y_batch = y_steps[batch.numpy()]
        prediction = model(torch.Tensor(x_batch.reshape((-1, 1))))
        loss = loss_func(prediction, torch.Tensor(y_batch.reshape((-1, 1))))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(loss.detach().numpy())  # loss of the last batch in this epoch

# Test if the function was approximated
predictions = model(torch.Tensor(np.reshape(x_steps, (-1, 1))))
plt.plot(x_steps, y_steps)
plt.show()
plt.plot(x_steps, predictions.detach().numpy())
plt.show()
```
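The manual index batching in the training loop can also be written with `torch.utils.data.TensorDataset` and `DataLoader`. This is a sketch using the same model, optimizer, learning rate and batch size as the post, running a single epoch only to show the loop structure. One behavioral difference: `shuffle=True` reshuffles the data every epoch, whereas the code above shuffles the indexes once before training.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def func(x):
    return np.sin(x * 50) * 500 + x * 2000

x_steps = np.linspace(0, 1, 5000)
y_steps = func(x_steps)

# Inputs and targets as float32 column tensors of shape (N, 1)
x_t = torch.from_numpy(x_steps).float().unsqueeze(1)
y_t = torch.from_numpy(y_steps).float().unsqueeze(1)
loader = DataLoader(TensorDataset(x_t, y_t), batch_size=64, shuffle=True)

# Same architecture and optimizer settings as in the post
model = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = optim.Adam(model.parameters(), lr=0.1)
loss_func = nn.MSELoss()

for epoch in range(1):  # a single epoch, just to show the structure
    n_batches = 0
    for x_batch, y_batch in loader:  # each batch has shape (<=64, 1)
        prediction = model(x_batch)
        loss = loss_func(prediction, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        n_batches += 1
    print(f"epoch {epoch}: {n_batches} batches, last batch loss {loss.item():.1f}")
```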