I am trying to train a neural network that uses a custom activation function. However, the loss remains constant across all epochs. When I print `weight.grad` for the linear layer, it returns `None`, which means that somewhere a gradient is not being computed. Here is my NN class code:
```python
class OptNet(nn.Module):
    def __init__(self):
        super(OptNet, self).__init__()
        m = 2
        self.fc1 = nn.Linear(m*m, 2*m*m, bias=False)
        with torch.no_grad():
            self.fc1.weight.data = torch.tensor([[1.1, 0, 0, 0],
                                                 [0, 1.3, 0, 0],
                                                 [0, 0, 1.2, 0],
                                                 [0, 0, 0, 1.3],
                                                 [-1, 0, 0, 0],
                                                 [0, -1, 0, 0],
                                                 [0, 0, -1, 0],
                                                 [0, 0, 0, -1]])
        self.myact = MyActivationFunction(m)

    def forward(self, x):
        x = self.fc1(x)
        return self.myact(x)
```
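In case it is relevant, here is a quick sanity check I can run on the layer itself (a minimal sketch, assuming the class above and the imports are in scope; the input size `m*m = 4` comes from my setup):

```python
# Quick sanity check (assumes OptNet and MyActivationFunction are defined as above)
net = OptNet()
print(net.fc1.weight.requires_grad)  # True, so the weight is a trainable parameter
out = net(torch.rand(4))             # m*m = 4 input features
print(out.requires_grad)             # should be True if the graph flows through the QP layer
```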
And here is the code for my activation function (it solves a quadratic program in which the inequality-constraint matrix is the reshaped `x`):
```python
class MyActivationFunction(nn.Module):
    def __init__(self, m=2):
        super(MyActivationFunction, self).__init__()
        self.m = m

    def forward(self, x):
        Q = torch.tensor([[1.3, 0.3], [0.3, 1.7]], dtype=torch.float32)
        q = torch.tensor([1, -1], dtype=torch.float32)
        A = torch.reshape(x, (4, 2))
        b = torch.ones(2*self.m, dtype=torch.float32)
        e = torch.tensor([], dtype=torch.float32)  # empty tensor: no equality constraints
        return QPFunction(verbose=-1)(Q, q, A, b, e, e)
```
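To double-check qpth on its own, here is a minimal standalone gradient check I put together. My understanding is that `QPFunction` solves min_z ½ zᵀQz + qᵀz subject to Gz ≤ h and Az = b, with the empty tensors standing in for the unused equality constraints; the random `G` below is just a stand-in for the reshaped `x`:

```python
import torch
from qpth.qp import QPFunction

# Standalone differentiability check for QPFunction
# (G plays the role of the reshaped x from the activation function above)
Q = torch.tensor([[1.3, 0.3], [0.3, 1.7]])
q = torch.tensor([1., -1.])
G = torch.rand(4, 2, requires_grad=True)
h = torch.ones(4)
e = torch.tensor([])  # no equality constraints

z = QPFunction(verbose=-1)(Q, q, G, h, e, e)
z.sum().backward()
print(G.grad)  # if qpth differentiates through the QP solve, this should be non-None
```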
And finally, here is the code used to actually train the network. The `dataMatrix` and `solutions` tensors are the training data set, which I have used for other purposes before, so I am fairly confident they are generated correctly.
```python
net = OptNet()
optimizer = optim.SGD(net.parameters(), lr=0.5)
loss_func = nn.MSELoss()
maxepochs = 100
lossarray = numpy.zeros(maxepochs)

for epoch in range(maxepochs):
    optimizer.zero_grad()
    predictions = torch.zeros(trainPoints, m)
    for eachpoint in range(trainPoints):
        predictions[eachpoint] = net(Variable(dataMatrix[eachpoint]))
    loss = Variable(loss_func(predictions, solutions), requires_grad=True)
    loss.backward()
    print(net.fc1.weight.grad)  # this prints None
    optimizer.step()
    lossarray[epoch] = loss
    if epoch % 5 == 1:
        print(loss)  # this is always constant
```
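One thing I can also do to narrow this down is print `grad_fn` at each stage, to see where the autograd graph might be getting cut (just a debugging sketch, reusing the same names as in the loop above):

```python
# Debugging sketch: inspect grad_fn at each stage of the loss computation
raw_loss = loss_func(predictions, solutions)
print(raw_loss.grad_fn)   # a backward node here means the graph reaches the loss
wrapped = Variable(raw_loss, requires_grad=True)
print(wrapped.grad_fn)    # None here would mean the wrapper created a new, detached leaf
```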
I figure it must be something to do with `MyActivationFunction`, which uses `QPFunction` from the qpth library, but qpth has differentiation through the QP solver implemented. Any help would be greatly appreciated, thanks!