I have been trying to implement a model these last few days and I stumbled across a problem that I cannot solve.
Here is the network (simplified):
The blue part has its `requires_grad` set to False, and I manually update its weight as:

```python
self.pred.weight.data = self.pred.weight.data + 0.1 * modulation * deltaW[0,0].T
```

where `modulation = torch.mean(out_modulation)`.
To make it clear: I run a simulation of N steps, and at every time step I manually update the weights of the online layer.
After all these steps I want to minimize the overall error by backpropagating (autograd) through the red area, where the error is the sum of the prediction errors in the blue area (i.e. the error does not explicitly depend on the modulation population).
The problem is that, since the error does not explicitly depend on the red area, PyTorch cannot compute a gradient for it. Is there a way to make the graph understand that the manual updates depend on the modulation population?
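For illustration, here is a minimal, self-contained repro of the issue (the names `w` and `m` are placeholders, not the model's actual tensors): because the update goes through `.data`, autograd never records it, so the modulation stand-in receives no gradient.

```python
import torch

w = torch.nn.Parameter(torch.ones(2, 2))   # stand-in for self.pred.weight
m = torch.tensor(0.5, requires_grad=True)  # stand-in for the modulation

# the manual update goes through .data, so autograd never records it
w.data = w.data + m * torch.ones(2, 2)

loss = (w ** 2).sum()
loss.backward()
print(m.grad)  # None: no gradient reaches the modulation stand-in
```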
Full forward code:

```python
def forward(self, previous, observation, hidden_modulation):
    prediction = self.pred(previous)  # linear unit
    prediction_error = observation - prediction
    # LSTM "surprise" module that produces the modulation
    out_modulation, hidden_modulation = self.surprise(torch.abs(prediction_error), hidden_modulation)
    modulation = torch.mean(torch.sigmoid(out_modulation))
    deltaW = torch.einsum("bij,bik->bijk", previous, prediction_error)
    self.pred.weight.data = self.pred.weight.data + modulation * deltaW[0, 0].T  # online update
    self.pred.weight.data[self.pred.weight.data < 0] = 0  # keep weights non-negative
    return prediction_error, hidden_modulation, torch.relu(prediction), modulation
```

(Note: the original snippet returned `hidden_pred, hidden_surprise`, which are not defined in this function; the return above uses `hidden_modulation` as computed.)