Equivalent of Keras's `mask` arg in the forward pass of an RNN

I'm looking for the PyTorch equivalent of the `mask` argument that Keras's RNN layers accept in the forward pass. PyTorch has the `pack_padded_sequence` function to skip the padded tail of each sequence, but Keras offers additional functionality that I need: the ability to skip RNN computations at arbitrary time steps and, according to `mask`, propagate the previous time step's state to the next time step.
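One possible workaround is to unroll the RNN manually with a cell module and, at each step, keep either the new hidden state or the previous one depending on the mask. The sketch below assumes a single-layer, unidirectional GRU, a boolean mask of shape `(batch, time)` where `True` means "process this step", and batch-first inputs. `MaskedGRU` is a hypothetical helper written for illustration, not a PyTorch API; it only relies on `nn.GRUCell` and `torch.where`.

```python
import torch
import torch.nn as nn


class MaskedGRU(nn.Module):
    """Hypothetical single-layer GRU that skips masked time steps.

    Where mask[:, t] is False, the cell's update is discarded and the
    previous hidden state is carried forward unchanged.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = nn.GRUCell(input_size, hidden_size)

    def forward(self, x: torch.Tensor, mask: torch.Tensor, h0=None):
        # x: (batch, time, input_size); mask: (batch, time) bool, True = keep step
        batch, steps, _ = x.shape
        h = x.new_zeros(batch, self.hidden_size) if h0 is None else h0
        outputs = []
        for t in range(steps):
            h_new = self.cell(x[:, t], h)
            keep = mask[:, t].unsqueeze(-1)  # (batch, 1), broadcasts over hidden dim
            # Use the new state where the step is kept, otherwise reuse the old one
            h = torch.where(keep, h_new, h)
            outputs.append(h)
        # outputs: (batch, time, hidden_size); h: final hidden state (batch, hidden_size)
        return torch.stack(outputs, dim=1), h
```

In this sketch the output at a masked step is simply the carried-over state; if you would rather emit zeros there, multiply `outputs` by the mask afterwards. A Python-level loop like this is typically much slower than the fused `nn.GRU`, so `pack_padded_sequence` is still the better choice when the mask only ever covers trailing padding.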


Let me know if you find a solution or workaround for this. I'm actually struggling with exactly the same problem.