How to implement a custom recurrent model without iterating over time?


I want to implement the Kalman filter in PyTorch. To do that, I have written a custom KalmanCell(nn.Module) that performs the Kalman update/predict step for a single time step.
I want to integrate the cell into a module that takes a time series as input and outputs the filtered (noise-free) time series.
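For concreteness, here is a minimal sketch of what such a cell could look like, assuming a linear-Gaussian model. The matrices F (state transition), H (observation), Q (process noise) and R (measurement noise), the batched tensor shapes, and the class body are all hypothetical and are not the actual KalmanCell; only the call signature (x_prior, P_prior, measurement) -> (x_post, P_post, next x_prior, next P_prior) is taken from the loop in the question.

```python
import torch
import torch.nn as nn


class KalmanCell(nn.Module):
    """Hypothetical sketch: one update + predict step of a linear Kalman filter.

    Assumed model: x_t = F x_{t-1} + w, w ~ N(0, Q);  z_t = H x_t + v, v ~ N(0, R).
    Assumed shapes: x (B, n), P (B, n, n), z (B, m).
    """

    def __init__(self, F, H, Q, R):
        super().__init__()
        # Buffers: fixed (untrained) matrices that still follow .to(device).
        self.register_buffer("F", F)
        self.register_buffer("H", H)
        self.register_buffer("Q", Q)
        self.register_buffer("R", R)

    def forward(self, x_prior, P_prior, z):
        F, H = self.F, self.H
        # Update: innovation covariance S = H P H^T + R, shape (B, m, m)
        S = H @ P_prior @ H.T + self.R
        PHt = P_prior @ H.T                                   # (B, n, m)
        # Gain K = P H^T S^{-1}, computed with a solve instead of an explicit inverse
        K = torch.linalg.solve(S, PHt.transpose(-1, -2)).transpose(-1, -2)
        innov = z - x_prior @ H.T                             # (B, m)
        x_post = x_prior + (K @ innov.unsqueeze(-1)).squeeze(-1)
        I = torch.eye(P_prior.shape[-1], dtype=P_prior.dtype, device=P_prior.device)
        P_post = (I - K @ H) @ P_prior
        # Predict: prior for the next time step
        x_next = x_post @ F.T
        P_next = F @ P_post @ F.T + self.Q
        return x_post, P_post, x_next, P_next
```

With a 1-D random walk (F = H = 1, Q = 0.1, R = 1), a prior of x = 0, P = 1 and a measurement z = 2, the step gives K = 0.5, so x_post = 1, P_post = 0.5 and the next prior variance is 0.6.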

My question is: how do I implement the module that calls KalmanCell without iterating over every time step of the input time series?

Currently I loop over the length T of the time series:

```python
for t in range(self.T):
    # one update/predict step; x_prior and P_prior carry the state into step t+1
    x_post, P_post, x_prior, P_prior = self._KalmanCell(x_prior, P_prior, input[:, t])
```
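A sketch of a wrapper module around this kind of loop, which also collects the filtered estimates into an output series. It assumes a hypothetical scalar random-walk cell (ScalarKalmanCell, F = H = 1) as a stand-in for KalmanCell with the same (x_prior, P_prior, z) -> (x_post, P_post, x_prior, P_prior) signature; the names KalmanFilter and ScalarKalmanCell are illustrative only. The loop over time is kept here because each step needs the previous step's state, but every step processes the whole batch at once.

```python
import torch
import torch.nn as nn


class ScalarKalmanCell(nn.Module):
    """Hypothetical stand-in for KalmanCell: scalar random walk (F = H = 1)
    with process noise q and measurement noise r; one update + predict step."""

    def __init__(self, q=0.1, r=1.0):
        super().__init__()
        self.q, self.r = q, r

    def forward(self, x_prior, P_prior, z):
        K = P_prior / (P_prior + self.r)          # Kalman gain
        x_post = x_prior + K * (z - x_prior)      # update with measurement z
        P_post = (1 - K) * P_prior
        # Predict with F = 1: mean unchanged, variance grows by q
        return x_post, P_post, x_post, P_post + self.q


class KalmanFilter(nn.Module):
    """Runs a cell over time, vectorized over the batch dimension."""

    def __init__(self, cell):
        super().__init__()
        self.cell = cell

    def forward(self, z, x0, P0):
        # z: (B, T) measurements; x0, P0: (B,) initial prior mean and variance
        x_prior, P_prior = x0, P0
        filtered = []
        for t in range(z.shape[1]):  # T is read from the input rather than fixed
            x_post, P_post, x_prior, P_prior = self.cell(x_prior, P_prior, z[:, t])
            filtered.append(x_post)
        # (B, T) filtered series plus the final prior, mirroring the
        # (output, h_n) pair that nn.RNN returns
        return torch.stack(filtered, dim=1), (x_prior, P_prior)
```

For example, KalmanFilter(ScalarKalmanCell(q=0.1, r=1.0)) applied to torch.tensor([[2.0, 2.0]]) with x0 = torch.zeros(1) and P0 = torch.ones(1) returns filtered estimates of [[1.0, 1.375]].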

Is there a way to get rid of the loop? Can I inherit from the RNN class? If so, how would I write this?

So far I have only seen examples that loop over time, such as this time series example from PyTorch.

I would appreciate any suggestions. Thanks!