Modify hidden state of LSTM and output

Hello everyone,
I want to train an LSTM, but I need to modify the calculations. I want to perform some computation on the hidden state before it gets passed on to the next step for the next element in the sequence. Let's say:

h' = o * \tanh(c')

But now I want to take this h, pass it through a fully connected layer, do some calculations with it to get another value A (which I want to be the actual output of the cell, instead of h), and then pass the modified h to the next step. I have tried modifying nn.LSTMCell, but I am not sure if this is even the right place to start. I have trained a net before using the plain nn.LSTM() class, and that worked well.
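For reference, the idea could be sketched roughly like this (the class name, the output size, and the tanh stand-in for the hidden-state modification are all hypothetical placeholders, not a tested solution):

```python
import torch
import torch.nn as nn

class ModifiedLSTM(nn.Module):
    """Hypothetical sketch: run an LSTMCell step by step, compute the
    actual output A from h via a fully connected layer, and carry a
    modified h on to the next time step."""
    def __init__(self, input_size, hidden_size, out_size):
        super().__init__()
        self.cell = nn.LSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, out_size)

    def forward(self, x, state):
        # x: (time_steps, batch, input_size)
        hx, cx = state
        outputs = []
        for t in range(x.size(0)):
            hx, cx = self.cell(x[t], (hx, cx))
            a = self.fc(hx)        # A becomes the output of this step
            outputs.append(a)
            hx = torch.tanh(hx)    # placeholder for "pass the modified h on"
        return torch.stack(outputs, dim=0), (hx, cx)
```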

Maybe someone has a recommendation on where and what I should change to achieve that. Thank you in advance.

nn.LSTMCell() is just a basic, single-time-step version of nn.LSTM where num_layers = 1 and bidirectional = False.

With that said, you can make changes to the hidden state (hx) or cell state (cx) by using a sequence length of 1 at each time step (assuming you wish to apply your calculation at every time step).

rnn = nn.LSTMCell(10, 20)  # (input_size, hidden_size)
input = torch.randn(2, 3, 10)  # (time_steps, batch, input_size)
hx = torch.randn(3, 20)  # (batch, hidden_size)
cx = torch.randn(3, 20)
output = []
for i in range(input.size(0)):
    hx, cx = rnn(input[i], (hx.detach(), cx.detach()))
    hx = hx * 2  # <<<< some calculation >>>>
    output.append(hx)  # collect the per-step outputs
output = torch.stack(output, dim=0)

Note: you might not need the .detach() calls to keep the graph from being carried across time steps under more recent torch versions, but I include them out of habit and they don't hurt anything.

Thank you for your answer. This definitely makes sense, and I think it solves the problem in general, but:

I have tried this already, and the problem was that it was quite slow, since I iterate over all the samples from the batch.
When I use the LSTM in a normal setup, the whole batch seems to be processed with one call. So is there a way to modify the function that actually does the computation on the whole batch? I hope it's clear what I mean; I'll try to show an example:

For standard LSTM with batch of 100:

output, h_c = self.lstm(x.squeeze(), (h.detach(), c.detach()))

This gives the output for the whole batch in one call.

Now what I am trying to do is make the same call, but internally change how the hidden state is calculated. Thanks.

I can’t tell from your code whether you mean batch size or sequence length. You can make the batch size any size your RAM can handle. The batch dimension is NOT processed as a sequence; since the calculations are independent of each other, they are processed in parallel.

But if you mean sequence length (which is dim = 0, unless you set batch_first = True): as far as I know, the underlying LSTM code in C++ iterates over the sequence length. That part is sequential.
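To illustrate the batch point, here is a minimal check (the sizes are arbitrary): a single nn.LSTMCell call already processes every sample in the batch at once, so the Python loop is only over time steps, never over samples.

```python
import torch
import torch.nn as nn

cell = nn.LSTMCell(10, 20)          # (input_size, hidden_size)
big_batch = torch.randn(5000, 10)   # (batch, input_size) -- one time step
h = torch.zeros(5000, 20)
c = torch.zeros(5000, 20)

# One call handles all 5000 samples; there is no per-sample loop.
h, c = cell(big_batch, (h, c))
```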

Thanks for the reply,
maybe I am not understanding the terms precisely. From my understanding, giving the lstm() function a tensor of length 100 means the length-100 output is calculated by unrolling the LSTM 100 times, internally passing the hidden states from step to step. Using the for loop you proposed, I do the same, but instead of giving the LSTM the whole sequence at once, I feed it step by step and manually pass the hidden state as an argument?
Sorry if this is confusing or completely wrong; I am familiar with how LSTMs work, but not with how exactly they behave in PyTorch.
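One way to convince yourself of this is to copy the weights of an nn.LSTM into an nn.LSTMCell and compare a full-sequence call against a manual loop; they should agree up to floating-point noise (a sketch, with arbitrary sizes):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(10, 20)       # processes the whole sequence in one call
cell = nn.LSTMCell(10, 20)   # processes one time step per call

# Copy the weights so both modules compute the same function.
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

x = torch.randn(100, 3, 10)        # (seq_len, batch, input_size)
h0 = torch.zeros(1, 3, 20)
c0 = torch.zeros(1, 3, 20)
out_full, _ = lstm(x, (h0, c0))    # internal unrolling over 100 steps

hx, cx = h0[0], c0[0]
steps = []
for t in range(x.size(0)):         # manual unrolling, step by step
    hx, cx = cell(x[t], (hx, cx))
    steps.append(hx)
out_loop = torch.stack(steps, dim=0)
```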

Please refer here to the descriptions of input, h_0, and c_0:

Link: LSTM — PyTorch 2.0 documentation

Your dims determine whether it's the batch size or the sequence length. For example:

sequence_length = 100
batch_size = 200
features_dim = 128
lstm_input = torch.rand((sequence_length, batch_size, features_dim))

# or with no batch
lstm_input = torch.rand((sequence_length, features_dim))