Customized RNN step functions

Hi!

I’m new to this area and apologize if I’m asking something obvious, or asking in the wrong place. Please feel free to point me elsewhere!

Is there an equivalent of tf.keras.backend.rnn in PyTorch? That is, something that takes a step_function as input and automates the loop over the time dimension so the computation is more streamlined? Right now I’m just using a regular loop and tracing the operation. That is fast, but running similar code with tf.keras.backend.rnn is still substantially faster. The closest functions I can find each implement a specific step function (e.g., an LSTM layer) and don’t allow a user-defined one.
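A minimal sketch of the kind of loop described above, for concreteness. It assumes the Keras-style convention where the step function takes (input_t, states) and returns (output_t, new_states), and that inputs are batch-first with shape (batch, time, features). The helper name rnn, the step function elman_step, and the weights W_x/W_h are hypothetical illustrations, not a PyTorch API.

```python
from typing import Callable, List, Tuple

import torch

# Hypothetical step-function signature, loosely following the Keras backend
# convention: step(input_t, states) -> (output_t, new_states)
StepFn = Callable[[torch.Tensor, List[torch.Tensor]],
                  Tuple[torch.Tensor, List[torch.Tensor]]]


def rnn(step_function: StepFn, inputs: torch.Tensor,
        initial_states: List[torch.Tensor]):
    """Apply step_function over dim 1 (time) of a batch-first `inputs`.

    Returns (last_output, outputs, final_states); outputs is stacked
    back along the time dimension.
    """
    states = initial_states
    outputs = []
    # unbind(dim=1) yields one (batch, features) slice per time step
    for x_t in inputs.unbind(dim=1):
        out_t, states = step_function(x_t, states)
        outputs.append(out_t)
    stacked = torch.stack(outputs, dim=1)
    return outputs[-1], stacked, states


# Example: a hypothetical plain (Elman) RNN cell as the user-defined step.
torch.manual_seed(0)
batch, time, in_dim, hidden = 4, 10, 3, 5
W_x = torch.randn(in_dim, hidden)
W_h = torch.randn(hidden, hidden)


def elman_step(x_t, states):
    h = torch.tanh(x_t @ W_x + states[0] @ W_h)
    return h, [h]


x = torch.randn(batch, time, in_dim)
h0 = torch.zeros(batch, hidden)
last, outs, final = rnn(elman_step, x, [h0])
print(outs.shape)  # torch.Size([4, 10, 5])
```

Each iteration here launches its own small kernels, which is one common source of per-step overhead in Python-level loops. A standard mitigation, independent of any library helper, is to hoist work that does not depend on the recurrent state out of the loop, e.g. computing x @ W_x for all time steps in a single matmul before iterating.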

Thanks very much!