Support for bidirectional_dynamic_rnn?

Has any work been done on supporting dynamic unrolling of inputs as in TF’s {bidirectional_}dynamic_rnn?


In PyTorch, a dynamic RNN over a custom cell is just a for loop. That is, the following two code snippets do the same thing (the first one is a simplified version of the implementation of tf.dynamic_rnn):

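```python
import tensorflow as tf

# TensorFlow (should be run once, during `__init__`)
# words: [time, batch, features]; rnn_unit and h0 are defined elsewhere
cond = lambda i, h: i < tf.shape(words)[0]
# the loop body must return the updated loop variables, including the incremented counter
body = lambda i, h: (i + 1, rnn_unit(words[i], h))
i = tf.constant(0)
_, h = tf.while_loop(cond, body, (i, h0))

# PyTorch (should be run for every batch, during `forward`)
h = h0
for word in words:  # iterates over the time dimension
    h = rnn_unit(word, h)
```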
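To make the PyTorch side concrete, here is a minimal sketch that assumes nn.LSTMCell as the "custom" cell and time-major input of shape [time, batch, features]. The helper name `unroll` is hypothetical. Since nn.LSTMCell and a single-layer nn.LSTM compute the same gate equations, copying the cell's weights into nn.LSTM should make the plain Python loop and the fused module agree up to floating-point tolerance.

```python
import torch
import torch.nn as nn


def unroll(cell, inputs, state):
    """Hypothetical helper: run `cell` over dim 0 (time) of `inputs` with a Python loop.

    inputs: [time, batch, input_size]; state: (h, c), each [batch, hidden_size].
    Returns stacked hidden states [time, batch, hidden_size] and the final (h, c).
    """
    outputs = []
    for x_t in inputs:  # loop length is whatever this batch's time dimension is
        state = cell(x_t, state)
        outputs.append(state[0])
    return torch.stack(outputs), state


torch.manual_seed(0)
T, B, I, H = 7, 3, 5, 4
cell = nn.LSTMCell(I, H)
lstm = nn.LSTM(I, H)  # single layer, time-major by default

# Copy the cell's parameters into nn.LSTM so both compute the same function.
with torch.no_grad():
    lstm.weight_ih_l0.copy_(cell.weight_ih)
    lstm.weight_hh_l0.copy_(cell.weight_hh)
    lstm.bias_ih_l0.copy_(cell.bias_ih)
    lstm.bias_hh_l0.copy_(cell.bias_hh)

x = torch.randn(T, B, I)
h0, c0 = torch.zeros(B, H), torch.zeros(B, H)

loop_out, (loop_h, loop_c) = unroll(cell, x, (h0, c0))
fused_out, (fused_h, fused_c) = lstm(x, (h0.unsqueeze(0), c0.unsqueeze(0)))
print(torch.allclose(loop_out, fused_out, atol=1e-5))
```

Nothing in the loop depends on a fixed T, which is what makes it "dynamic": the next batch can have a different number of time steps.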

Thanks. The python/ops/ code is so involved and convoluted that I thought there must be something more going on than that.

@jekbradbury wouldn’t it be possible (and faster) to run the loop entirely on the GPU?

In PyTorch, running a unidirectional, one-layer arbitrary cell is easy (as @jekbradbury showed in his snippet). It becomes more involved if you need bidirectional or stacked recurrent cells: you either have to replicate the bidirectional/stacked mechanics from nn/_functions/, or add your cell all over the place in nn/_functions/. @csarofeen had a proposal to make that easier, but so far it has gone nowhere.
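As a rough illustration of what "replicating the bidirectional mechanics" means for a custom cell, here is a minimal sketch (not the nn/_functions/ code) that assumes equal-length sequences with no padding. The function names `run_cell` and `bidirectional` are hypothetical. The backward cell reads the sequence reversed in time and its outputs are flipped back, so that both halves at position t refer to the same input step. With padded batches, a plain `flip(0)` would put the padding first for shorter sequences, which is exactly the bookkeeping that makes the real implementation more involved.

```python
import torch
import torch.nn as nn


def run_cell(cell, inputs, state):
    """Unroll `cell` over dim 0 of `inputs`; returns stacked h and the final (h, c)."""
    outputs = []
    for x_t in inputs:
        state = cell(x_t, state)
        outputs.append(state[0])
    return torch.stack(outputs), state


def bidirectional(fwd_cell, bwd_cell, inputs, state):
    """Hypothetical bidirectional wrapper; assumes every sequence has the full length."""
    fwd_out, fwd_state = run_cell(fwd_cell, inputs, state)
    bwd_out, bwd_state = run_cell(bwd_cell, inputs.flip(0), state)
    out = torch.cat([fwd_out, bwd_out.flip(0)], dim=-1)  # [time, batch, 2 * hidden]
    return out, fwd_state, bwd_state


torch.manual_seed(0)
T, B, I, H = 6, 2, 3, 4
fwd_cell, bwd_cell = nn.LSTMCell(I, H), nn.LSTMCell(I, H)
lstm = nn.LSTM(I, H, bidirectional=True)

# Check against nn.LSTM by copying both cells' weights into its two directions.
with torch.no_grad():
    for suffix, c in (("", fwd_cell), ("_reverse", bwd_cell)):
        for name in ("weight_ih", "weight_hh", "bias_ih", "bias_hh"):
            getattr(lstm, f"{name}_l0{suffix}").copy_(getattr(c, name))

x = torch.randn(T, B, I)
zeros = (torch.zeros(B, H), torch.zeros(B, H))
out, (fh, fc), (bh, bc) = bidirectional(fwd_cell, bwd_cell, x, zeros)
ref_out, (ref_h, ref_c) = lstm(x)  # default initial state is zeros
```

Stacking works the same way: feed `out` from one layer as the `inputs` of the next.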
I don’t quite understand what you mean by running the loop entirely on the GPU. All the data operations will be performed on the GPU if your input, hidden state and weights are on the GPU; the overhead of the Python loop is negligible. The biggest performance problem with a custom RNN cell is not loop overhead, it’s that the pointwise operations are not fused. For a typical recurrent cell such as an LSTM there will be 6-10 of them per time step, and performance will be limited by the launch latency of those 6-10 kernels.
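To see where those pointwise operations come from, here is a sketch of an LSTM cell written out by hand; the class name `NaiveLSTMCell` is hypothetical. The two matrix multiplies do the heavy lifting, and everything after them is elementwise. In eager mode each of those elementwise ops typically launches its own kernel on the GPU, so a cell written this way pays the per-kernel launch latency about ten times per step. The exact count depends on how the cell is written.

```python
import torch
import torch.nn as nn


class NaiveLSTMCell(nn.Module):
    """Hand-written LSTM cell using the same gate layout (i, f, g, o) as nn.LSTMCell."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.ih = nn.Linear(input_size, 4 * hidden_size)   # matmul + bias
        self.hh = nn.Linear(hidden_size, 4 * hidden_size)  # matmul + bias

    def forward(self, x, state):
        h, c = state
        gates = self.ih(x) + self.hh(h)     # 1 pointwise add
        i, f, g, o = gates.chunk(4, dim=1)  # views, no extra kernel
        i = torch.sigmoid(i)                # pointwise
        f = torch.sigmoid(f)                # pointwise
        g = torch.tanh(g)                   # pointwise
        o = torch.sigmoid(o)                # pointwise
        c = f * c + i * g                   # 3 pointwise ops (2 mul, 1 add)
        h = o * torch.tanh(c)               # 2 pointwise ops (tanh, mul)
        return h, c


torch.manual_seed(0)
ref = nn.LSTMCell(3, 4)
naive = NaiveLSTMCell(3, 4)
with torch.no_grad():  # give both cells identical parameters
    naive.ih.weight.copy_(ref.weight_ih)
    naive.ih.bias.copy_(ref.bias_ih)
    naive.hh.weight.copy_(ref.weight_hh)
    naive.hh.bias.copy_(ref.bias_hh)

x = torch.randn(2, 3)
state = (torch.randn(2, 4), torch.randn(2, 4))
h_naive, c_naive = naive(x, state)
h_ref, c_ref = ref(x, state)
```

The fused nn.LSTM path avoids most of this per-step overhead, which is where the speedup mentioned below comes from.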

Also, note that

  • you can use pack_padded_sequence to allow sequences of different length inside a minibatch
  • it’s perfectly fine to use minibatches with different sequence lengths in every call to LSTM
  • if your application allows it, using nn.LSTM instead of a manually unrolled nn.LSTMCell can easily give you a 10x speedup.
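For the original question, a bidirectional nn.LSTM over a minibatch of variable-length sequences might look like the sketch below. It assumes a PyTorch version whose `pack_padded_sequence` accepts `enforce_sorted=False` (1.1 or later); with older versions you would sort the batch by decreasing length first. The sizes and lengths are made up for illustration. Each sequence's valid outputs should match running the same LSTM on that sequence alone, and positions past a sequence's length come back as zeros from `pad_packed_sequence`.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

torch.manual_seed(0)
I, H = 3, 4
lstm = nn.LSTM(I, H, batch_first=True, bidirectional=True)

# Three sequences of different lengths, zero-padded to the longest one.
lengths = torch.tensor([5, 2, 4])
seqs = [torch.randn(int(n), I) for n in lengths]
padded = pad_sequence(seqs, batch_first=True)  # [batch, max_len, I]

# Packing tells the LSTM where each sequence ends, so the backward direction
# starts at each sequence's real last step rather than at the padding.
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)  # [batch, max_len, 2 * H]

# A later batch with a different maximum length goes through the same module unchanged.
other_out, _ = lstm(torch.randn(2, 9, I))
```

`h_n` has shape [2, batch, H]: index 0 is the forward direction's final state and index 1 the backward direction's, both returned in the original batch order.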

@elanmart I’m familiar with the call from the API docs, but I don’t see it used anywhere in the examples. The section of code I’m looking at is:

Can I just pack/pad the sentences in the minibatch and feed that to a BiLSTM?

No, because it looks like you have to apply attention. Look at for an attention example.