I am trying to train a neural network that uses a custom activation function. However, the loss stays constant across all epochs. When I print weight.grad for the linear layer, it is None, which means that a gradient is not being computed somewhere along the way. Here is my NN class code:
import numpy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from qpth.qp import QPFunction

class OptNet(nn.Module):
    def __init__(self):
        super(OptNet, self).__init__()
        m = 2
        self.fc1 = nn.Linear(m*m, 2*m*m, bias=False)
        # Start from a hand-picked 8x4 weight matrix
        with torch.no_grad():
            self.fc1.weight.data = torch.tensor([[1.1,0,0,0],[0,1.3,0,0],[0,0,1.2,0],[0,0,0,1.3],[-1,0,0,0],[0,-1,0,0],[0,0,-1,0],[0,0,0,-1]])
        self.myact = MyActivationFunction(m)

    def forward(self, x):
        m = 2
        x = self.fc1(x)
        return self.myact(x)
And here is my activation function code (it is meant to solve a quadratic program in which the reshaped x is one of the parameters):
class MyActivationFunction(nn.Module):
    def __init__(self, m=2):
        super(MyActivationFunction, self).__init__()
        self.m = m

    def forward(self, x):
        Q = torch.tensor([[1.3, 0.3], [0.3, 1.7]], dtype=torch.float32)
        q = torch.tensor([1,-1], dtype=torch.float32)
        A = torch.reshape(x, (4,2))
        b = torch.ones(2*self.m, dtype=torch.float32)
        e = torch.tensor([], dtype=torch.float32)
        return QPFunction(verbose=-1)(Q, q, A, b, e, e)
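For reference, this is the problem I understand the activation to be solving. It assumes qpth's argument order is (Q, p, G, h, A, b), so that my reshaped x is passed as the inequality matrix and the two empty tensors mean there are no equality constraints; the symbol z for the decision variable is my own notation:

```latex
\begin{aligned}
z^\star(x) = \arg\min_{z \in \mathbb{R}^2} \quad & \tfrac{1}{2}\, z^\top Q z + q^\top z \\
\text{subject to} \quad & A(x)\, z \le b,
\end{aligned}
\qquad
Q = \begin{bmatrix} 1.3 & 0.3 \\ 0.3 & 1.7 \end{bmatrix},\;
q = \begin{bmatrix} 1 \\ -1 \end{bmatrix},\;
A(x) = \operatorname{reshape}(x,\, 4 \times 2),\;
b = \mathbf{1}_4 .
```

Here x is the 8-dimensional output of fc1, so the gradient of the loss has to pass through A(x) to reach fc1.weight.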
And finally, here is the code used to actually train the network. The dataMatrix and solutions tensors are the training data set, which I have used for other purposes before so I am fairly confident these are generated correctly.
net = OptNet()
optimizer = optim.SGD(net.parameters(), lr=0.5)
loss_func = nn.MSELoss()
maxepochs = 100
lossarray = numpy.zeros(maxepochs)
for epoch in range(maxepochs):
    optimizer.zero_grad()
    predictions = torch.zeros(trainPoints,m)
    for eachpoint in range(trainPoints):
        predictions[eachpoint] = net(Variable(dataMatrix[eachpoint]))
    loss = Variable(loss_func(predictions, solutions), requires_grad=True)
    loss.backward()
    print(net.fc1.weight.grad)  # this prints None
    optimizer.step()
    lossarray[epoch] = loss
    if epoch%5 == 1:
        print(loss)  # this is always constant
I figure it must have something to do with MyActivationFunction, which uses QPFunction from the qpth library, but that library has differentiation implemented. Any help would be greatly appreciated, thanks!
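One check that might help narrow this down is sketched below. It is only a minimal sketch that leaves qpth out entirely: `grad_after_backward` and the plain linear stand-in are hypothetical and not part of my code. It compares `loss.grad_fn` and the weight gradient when the loss is used directly versus when it is detached and turned into a new tensor that requires grad. I am assuming that the detach step mimics what happens when an already computed loss is wrapped in a fresh Variable.

```python
import torch
import torch.nn as nn


def grad_after_backward(rewrap_loss):
    """Return (loss.grad_fn, weight.grad) after one backward pass."""
    torch.manual_seed(0)
    # Hypothetical stand-in for fc1 followed by the activation
    layer = nn.Linear(4, 2, bias=False)
    inputs = torch.ones(3, 4)
    targets = torch.zeros(3, 2)

    # Fill a preallocated tensor row by row, as in the training loop
    predictions = torch.zeros(3, 2)
    for i in range(3):
        predictions[i] = layer(inputs[i])

    loss = nn.MSELoss()(predictions, targets)
    if rewrap_loss:
        # Assumption: mimics re-wrapping the loss as a new tensor that requires grad
        loss = loss.detach().requires_grad_(True)

    grad_fn = loss.grad_fn  # None means the loss has no history to backpropagate through
    loss.backward()
    return grad_fn, layer.weight.grad


for rewrap in (False, True):
    grad_fn, weight_grad = grad_after_backward(rewrap)
    print(f"rewrap_loss={rewrap}: grad_fn={grad_fn}, weight.grad is None: {weight_grad is None}")
```

If `grad_fn` is None for the loss in the real training loop as well, the graph is probably being cut before `backward()` is called rather than inside QPFunction.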