I am switching from TF/Keras to PyTorch. To build a recurrent network with a custom cell, TF provides the handy wrapper `tf.keras.layers.TimeDistributed`, which handles the sequence for you and applies your arbitrary cell to each time step.
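For comparison, a stateless TimeDistributed-style wrapper can be sketched in PyTorch by folding the time axis into the batch axis, so the wrapped module runs once over all time steps. This is a minimal sketch under two assumptions: the `TimeDistributed` class below is a hypothetical helper (it is not part of `torch.nn`), and inputs are shaped `(batch, time, ...)`. Because nothing is carried from one step to the next, this only fits layers that treat each step independently.

```python
import torch
import torch.nn as nn


class TimeDistributed(nn.Module):
    """Hypothetical helper (not in torch.nn): apply `module` to every time step.

    Expects input of shape (batch, time, ...). No state is passed between
    steps, so batch and time are merged and the module is called once.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, time = x.shape[:2]
        # Merge batch and time: (batch * time, ...)
        y = self.module(x.reshape(batch * time, *x.shape[2:]))
        # Split them again: (batch, time, ...)
        return y.reshape(batch, time, *y.shape[1:])


layer = TimeDistributed(nn.Linear(8, 4))
out = layer(torch.randn(2, 5, 8))
print(out.shape)  # torch.Size([2, 5, 4])
```

Note that `nn.Linear` already accepts arbitrary leading dimensions on its own; a wrapper like this is mainly useful for modules that expect a fixed input rank, such as `nn.Conv2d`.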
While trying to implement something similar in PyTorch, I found code examples (like https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html) that simply use a for-loop over the time steps. This works, but I was wondering whether there is a 'better' way to do it, or whether this is already optimal performance-wise. Coming from TF, I doubt that such a simple and elegant for-loop solution can be optimal, but PyTorch has already surprised me many times.
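The for-loop approach from such examples can be sketched as a generic wrapper that unrolls any cell over the time axis. This is a minimal sketch, not the tutorial's exact code: `MyCell` and `Recurrent` are hypothetical names, the cell is assumed to have the signature `cell(x_t, h) -> h` and a `hidden_size` attribute, and inputs are assumed to be shaped `(batch, time, input_size)`.

```python
import torch
import torch.nn as nn


class MyCell(nn.Module):
    """Hypothetical custom cell: h_t = tanh(W_x x_t + W_h h_{t-1})."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)

    def forward(self, x_t: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self.i2h(x_t) + self.h2h(h))


class Recurrent(nn.Module):
    """Unroll any cell with signature cell(x_t, h) -> h over dimension 1."""

    def __init__(self, cell: nn.Module):
        super().__init__()
        self.cell = cell

    def forward(self, x: torch.Tensor, h: torch.Tensor = None):
        # x: (batch, time, input_size)
        if h is None:
            # Start from a zero hidden state with x's dtype and device
            h = x.new_zeros((x.size(0), self.cell.hidden_size))
        outputs = []
        for t in range(x.size(1)):
            h = self.cell(x[:, t], h)  # each step needs the previous h
            outputs.append(h)
        # Stack per-step outputs back into (batch, time, hidden_size)
        return torch.stack(outputs, dim=1), h


rnn = Recurrent(MyCell(3, 4))
out, h_last = rnn(torch.randn(2, 6, 3))
print(out.shape, h_last.shape)  # torch.Size([2, 6, 4]) torch.Size([2, 4])
```

The loop itself is inherently sequential, since every hidden state depends on the previous one. Work that does not depend on `h` (here `self.i2h`) could instead be computed for all time steps in one batched call before the loop. For standard cells, `nn.RNN`, `nn.LSTM` and `nn.GRU` run the whole sequence internally and can use fused cuDNN kernels on the GPU.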