RNN/LSTM Batch Training Question


#1

When training in batches, is there any way to prevent the RNN from taking the final hidden state of one sequence as the input to the first timestep of the next sequence? That is a relationship that shouldn’t be learned, because in my case the time series won’t necessarily be in order.


#2

I’m not an expert on this, but I think that when you train on a batch, all of its sequences are forwarded at the same time, and the hidden state of one sequence in the batch is not used as the input to the next sequence. That’s why the hidden states have size “lstm_layers X batch size X hidden size”.
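To make the shape argument concrete, here is a minimal sketch assuming PyTorch's `nn.LSTM` with its default `batch_first=False` layout. The sizes are arbitrary values picked for illustration. The returned hidden state has a separate slice for every sequence in the batch, which is what you would expect if each sequence carries its own state.

```python
import torch
import torch.nn as nn

# Arbitrary illustrative sizes (not from the original question).
num_layers, batch_size, hidden_size = 2, 4, 3
seq_len, input_size = 5, 2

# Default layout is (seq_len, batch, input_size).
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers)

x = torch.randn(seq_len, batch_size, input_size)
output, (h_n, c_n) = lstm(x)

# One hidden/cell state per layer AND per sequence in the batch.
print(tuple(output.shape))  # (seq_len, batch_size, hidden_size)
print(tuple(h_n.shape))     # (num_layers, batch_size, hidden_size)
print(tuple(c_n.shape))     # (num_layers, batch_size, hidden_size)
```

Note that for a bidirectional LSTM the first dimension of `h_n` would be `2 * num_layers` instead.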


(Novak) #3

It’s not immediately clear what it is that you want to do, or what you’re already doing.

Are you talking about training a batch (or mini-batch) in an off-line sense by a single call to the forward function of the model? If that is what you are doing, then I can say with very high certainty that the individual sequences are processed in parallel and their time evolution is independent.

In the absence of better tutorials than I have been able to find, it is highly instructive to code up a toy model of an RNN, something like a batch size of 4, a sequence length of 5, and an input width of 2. This toy model is small enough that you can design your input batch tensor by hand in order to prove a point to yourself.

In this case, just code up all 4 sequences to be the same sequence. The same random sequence, even. Then you can inspect the output of this and see that all four output sequences are the same. This would not be the case if the final hidden output of the first sequence were fed back into the first hidden state for the second sequence.
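The experiment described above could be sketched roughly as follows, assuming PyTorch's `nn.RNN` with `batch_first=True` and a hidden size of 3 (the hidden size and the extra perturbation check are my own additions for illustration). It builds a batch of 4 copies of one random sequence, checks that all four outputs agree, and then changes only the first sequence to see whether the others are affected.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len, input_size = 4, 5, 2
hidden_size = 3  # assumed value, not specified in the post

rnn = nn.RNN(input_size, hidden_size, batch_first=True)

# One random sequence, copied into every slot of the batch.
one_seq = torch.randn(seq_len, input_size)
batch = one_seq.unsqueeze(0).repeat(batch_size, 1, 1)  # (4, 5, 2)

with torch.no_grad():
    output, h_n = rnn(batch)

# If state leaked from sequence i to sequence i+1, these would differ.
all_same = all(torch.allclose(output[0], output[i]) for i in range(1, batch_size))
print("all four outputs identical:", all_same)

# Extra check: perturb only the first sequence and rerun.
perturbed = batch.clone()
perturbed[0] += 1.0
with torch.no_grad():
    output2, _ = rnn(perturbed)

# Sequences 1..3 should be untouched by the change to sequence 0.
others_unchanged = torch.allclose(output[1:], output2[1:], atol=1e-6)
print("other sequences unaffected:", others_unchanged)
```

Both checks passing is consistent with the sequences in a batch evolving independently of one another.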

Finally, remember that you have the ability to define what happens to the hidden states in your forward method. I do not think it is required here, though, so I will not lengthen an already long post.


(Thomas V) #4

The hidden state of the sequence will not be kept in the RNN/LSTM/… modules - you can either pass it in or it will be set to 0s for each invocation. Only the weights are learned.
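A small sketch of this behaviour, assuming PyTorch's `nn.LSTM` (the sizes are arbitrary illustrative values): calling the module twice on the same input without passing a state gives the same result, that result matches an explicit all-zero initial state, and the output only changes when you pass the previous state in yourself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary illustrative sizes.
batch_size, seq_len, input_size, hidden_size = 4, 5, 2, 3

lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
x = torch.randn(batch_size, seq_len, input_size)

with torch.no_grad():
    out_a, (h_a, c_a) = lstm(x)   # first call, no state passed
    out_b, _ = lstm(x)            # second call, still no state passed

    # Explicit zero state: shape (num_layers, batch, hidden_size).
    zeros = torch.zeros(1, batch_size, hidden_size)
    out_c, _ = lstm(x, (zeros, zeros))

    # Carry the state over by hand from the first call.
    out_d, _ = lstm(x, (h_a, c_a))

same_across_calls = torch.allclose(out_a, out_b)
same_as_zero_init = torch.allclose(out_a, out_c)
differs_when_carried = not torch.allclose(out_a, out_d)

print("nothing kept between calls:", same_across_calls)
print("default equals zero init:  ", same_as_zero_init)
print("changes if state passed in:", differs_when_carried)
```

So if your sequences are unordered, simply not passing a hidden state into each forward call should be enough.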

Best regards

Thomas