Dear community,
I’m a noob with Pytorch and I tried to fit the exponent of an equation with a custom activation fuction. However, I failed in defining it. Why?
import math
import torch
from torch.autograd import Variable
from torch import optim
from torch.nn.parameter import Parameter
class powerActivation(nn.Module):
def __init__(self):
super(powerActivation, self).__init__()
self.weight = Parameter(torch.Tensor(1, 1))
self.reset_parameters()
def reset_parameters(self):
self.weight.data.uniform_(1, 2)
def forward(self, x):
return x**self.weight
def build_model():
model = torch.nn.Sequential()
model.add_module("linear", powerActivation())
return model
def train(model, loss, optimizer, x, y):
x = Variable(x, requires_grad=False)
y = Variable(y, requires_grad=False)
# Reset gradient
optimizer.zero_grad()
# Forward
fx = model.forward(x.view(len(x), 1))
output = loss.forward(fx, y)
# Backward
output.backward()
# Update parameters
optimizer.step()
return output.data[0]
def main():
torch.manual_seed(42)
X = torch.linspace(2, 10, 101)
Y = X **2
model = build_model()
loss = torch.nn.MSELoss(size_average=True)
optimizer = optim.SGD(model.parameters(), lr=0.1)
batch_size = 10
for i in range(20):
cost = 0.
num_batches = len(X) // batch_size
for k in range(num_batches):
start, end = k * batch_size, (k + 1) * batch_size
cost += train(model, loss, optimizer, X[start:end], Y[start:end])
print("Epoch = %d, cost = %s" % (i + 1, cost / num_batches))
w = next(model.parameters()).data # model has only one parameter
print("w = %.2f" % w.numpy()) # will be approximately 2
print(model(Variable(X)).data)
print(list(zip(X,Y)))
main()
the result is:
Epoch = 1, cost = 2439.914280539751
Epoch = 2, cost = 2449.544422531128
Epoch = 3, cost = 2449.544422531128
Epoch = 4, cost = 2449.544422531128
Epoch = 5, cost = 2449.544422531128
Epoch = 6, cost = 2449.544422531128
Epoch = 7, cost = 2449.544422531128
Epoch = 8, cost = 2449.544422531128
Epoch = 9, cost = 2449.544422531128
Epoch = 10, cost = 2449.544422531128
Epoch = 11, cost = 2449.544422531128
Epoch = 12, cost = 2449.544422531128
Epoch = 13, cost = 2449.544422531128
Epoch = 14, cost = 2449.544422531128
Epoch = 15, cost = 2449.544422531128
Epoch = 16, cost = 2449.544422531128
Epoch = 17, cost = 2449.544422531128
Epoch = 18, cost = 2449.544422531128
Epoch = 19, cost = 2449.544422531128
Epoch = 20, cost = 2449.544422531128
w = -21.63
Thank you