RNNs: Understanding the difference between feeding a sequence as a whole vs. step-by-step

Hello,

I’m a bit lost trying to understand what PyTorch does internally in these two ways of passing a sequence through an RNN model:

Version 1:
out, hidden = model(sequence)

Version 2:

hidden = None  # start from the default (zero) initial hidden state
for element in sequence:
    # element is (batch, input_size); unsqueeze restores a time dimension of 1
    out, hidden = model(element.unsqueeze(0), hidden)

In my understanding, versions 1 and 2 should be equivalent - I’m just manually unrolling the sequence over time in version 2. Version 2 also lets me do things like modify the hidden state before passing it to the next timestep (which is what I would like to do for my use case), and train without teacher forcing (by feeding the last “out” back in instead of “element”). However, I noticed that version 1 trains much faster (about 8x on my machine).
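For concreteness, the kind of thing I have in mind for version 2 looks roughly like this (just a sketch: modify_hidden and to_input are placeholders for the state tweak and for projecting the output back into an input, and the model and shapes are assumed, not taken from my actual code):

import torch
import torch.nn as nn

input_size, hidden_size = 10, 20
model = nn.RNN(input_size, hidden_size)        # stand-in for my model
to_input = nn.Linear(hidden_size, input_size)  # placeholder projection so "out" can be fed back in

sequence = torch.randn(5, 3, input_size)       # (seq_len, batch, input_size)

def modify_hidden(h):                          # placeholder for whatever I do to the hidden state
    return h

hidden = None                                  # default (zero) initial state
inp = sequence[0].unsqueeze(0)                 # first ground-truth element, shape (1, batch, input_size)
outputs = []
for t in range(sequence.size(0)):
    out, hidden = model(inp, hidden)
    hidden = modify_hidden(hidden)             # tweak the state before the next timestep
    outputs.append(out)
    inp = to_input(out)                        # feed the prediction back in instead of the ground truth
out = torch.cat(outputs, dim=0)                # (seq_len, batch, hidden_size)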

Question 1: Is version 1 faster simply because it makes fewer forward calls, or does PyTorch internally parallelise something across timesteps (which I don’t see how it could do, given that each timestep’s computation depends on the output of the previous one)?

Question 2: More generally, is version 2 indeed the way to go if I need to train without teacher forcing or if I need to modify “hidden” between timesteps?

Thank you for your help!

If you look at the source code, version 1 internally runs version 2.
There is a Base_RNN class or something like that; I cannot find it right now, but at least it used to be like that. In any case, according to the theory it has to work that way: it cannot be parallelized across timesteps.


Thank you for your reply! What do you think causes the speed difference if V1 runs V2 internally?

So the stock LSTM will use cuDNN, which has been heavily optimized but is inflexible.
But the JIT can make custom RNNs almost as fast: some of this is discussed in the LSTM with JIT blog post, and we also have optimizations for the backward pass.
You can look at the benchmark code to get ideas on how to use them.
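To give a rough idea of the pattern (a minimal sketch of a scripted LSTM cell in the spirit of the blog post, not the exact benchmark code):

from typing import Tuple
import torch
import torch.nn as nn
import torch.jit as jit
from torch import Tensor

class ScriptLSTMCell(jit.ScriptModule):
    # One timestep of an LSTM; the JIT can fuse the pointwise gate math.
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = nn.Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = nn.Parameter(torch.randn(4 * hidden_size))

    @jit.script_method
    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        cy = torch.sigmoid(forgetgate) * cx + torch.sigmoid(ingate) * torch.tanh(cellgate)
        hy = torch.sigmoid(outgate) * torch.tanh(cy)
        return hy, (hy, cy)

You would then write the loop over timesteps as another scripted module (or function) that calls this cell; that loop is where the benchmarks compare against cuDNN.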

Best regards

Thomas


As I said, it used to be like that, at least the last time I checked the source code, but it seems it has been improved since. :)
Thanks @tom, awesome info as always!
