How can I customise an LSTM cell and integrate it to work with autograd? Suppose some of the weights in the LSTM cell are described by a formula and should be updated through that formula. For example, in the snippet below I try a cubic/quadratic approximation of the input weights, but when I print them I see that the weights never change. Here is a minimal example of what I am trying to do:
import torch
import torch.nn as nn

class CubicLSTM(nn.Module):
    def __init__(self, input_size=1, hidden_size=64, num_layers=1, output_size=1):
        super(CubicLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # Learnable coefficient of the cubic term (must be float to carry gradients)
        self.a = nn.Parameter(torch.tensor(2.0))
        self.fc1 = nn.Linear(hidden_size, output_size)

    def _update_lstm_weights(self):
        """Overwrite the LSTM input weights (weight_ih_l0) with the cubic form a * z**3."""
        # weight_ih_l0 has shape (4 * hidden_size, input_size)
        z = torch.arange(4 * self.hidden_size, dtype=torch.float32)
        W = (self.a * z ** 3).unsqueeze(1)  # (4 * hidden_size, 1), broadcast by copy_
        with torch.no_grad():
            self.lstm.weight_ih_l0.copy_(W)
        # self.lstm.weight_ih_l0.copy_(W)  # uncomment (and drop the no_grad block)
        #                                  # to try the copy with grad enabled

    def forward(self, x):
        self._update_lstm_weights()
        out, (hn, cn) = self.lstm(x)
        out = self.fc1(hn[-1])
        return torch.tanh(out)
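This is how I observe the problem (a minimal check with dummy data; the shapes and learning rate are just placeholders):

model = CubicLSTM()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(8, 10, 1)  # dummy batch: (batch, seq_len, input_size)
y = torch.randn(8, 1)

for step in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    # 'a' is what generates the input weights, but its grad stays None,
    # so the formula-driven weights are identical on every iteration
    print(step, model.a.item(), model.a.grad)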
I also tried adding those weights directly to the optimiser, but I see that they are never updated.
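Concretely, this is roughly what I tried (the exact parameter list is just a sketch of the idea):

optimizer = torch.optim.Adam(
    [model.a, model.lstm.weight_ih_l0] + list(model.fc1.parameters()),
    lr=1e-3,
)
# even then, the printed input weights are identical after every forward
# pass, and model.a.grad is still None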