Fit a custom function's parameters

Hi,

I’m trying to do something very basic, but I’m not quite sure if it’s possible with PyTorch (and how to do it).

I want to fit a custom function to a series of data in order to get the fitting parameters from it (assuming it converges, of course).

I want to do something like this: let’s say I have a series of data x = [x1, x2, x3, ..., xn], y = [y1, y2, y3, ..., yn], and my hypothesis is that it can be fitted by a function f(x) with coefficients A, B and C, e.g. a polynomial y = Ax^2 + Bx + C.

What would a general model for this look like? My guess is something like the code below, but I’m not sure how to tell PyTorch which variables to vary in order to fit it:

class NN():
    def __init__(self):
        super().__init__()
        self.fx = A * x**2 + B * x + C

    def forward(self, x):
        y0 = self.fx(x)

Any advice on whether that’s the right way to do it?

You would define the trainable parameters in the __init__ method and use them in the forward pass. This tutorial might be helpful.
For your example, this should work:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

class NN(nn.Module):
    def __init__(self):
        super().__init__()
        # register the polynomial coefficients as trainable parameters
        self.A = nn.Parameter(torch.randn(1))
        self.B = nn.Parameter(torch.randn(1))
        self.C = nn.Parameter(torch.randn(1))

    def forward(self, x):
        out = self.A * x**2 + self.B * x + self.C
        return out

# create data    
x = torch.linspace(0, 10, 1000)
A, B, C = 2, 3, 4
target = A * x**2 + B * x + C

plt.plot(x.numpy(), target.numpy())

# model training
model = NN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

nb_epochs = 25000
for epoch in range(nb_epochs):
    optimizer.zero_grad()          # reset gradients
    out = model(x)                 # forward pass
    loss = criterion(out, target)  # MSE between prediction and target
    loss.backward()                # backpropagate
    optimizer.step()               # update A, B, C
    print("epoch: {}, loss: {:.3f}\nA: {}, B: {}, C: {}".format(
        epoch, loss.item(), model.A.item(), model.B.item(), model.C.item()))

# ...
# epoch: 24999, loss: 0.000
# A: 2.0, B: 3.0, C: 4.0
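
If you want to check how well the fit matches the data, you could evaluate the trained model without gradient tracking and plot it against the target (a minimal sketch reusing the x, target, and model defined above):

# sketch: compare the fitted polynomial to the target data
with torch.no_grad():
    fitted = model(x)

plt.plot(x.numpy(), target.numpy(), label="target")
plt.plot(x.numpy(), fitted.numpy(), label="fitted")
plt.legend()
plt.show()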

Perfect!!! That’s what I was missing, thanks a lot!!

How do I get the values of the parameters after training?

The print statement directly accesses the trained parameters and shows their values during training.
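
If you want them after the loop has finished, you can read the same attributes directly or go through the usual nn.Module utilities (a small sketch using the model trained above):

# read the fitted coefficients directly
print(model.A.item(), model.B.item(), model.C.item())

# or iterate over all registered parameters by name
for name, param in model.named_parameters():
    print(name, param.item())

# or inspect the full state dict
print(model.state_dict())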
