Is there a good solution for this problem?

Hi, I need to do serial temporal feature extraction. The input is a sequence like the one shown below.

Each feature map has shape (25, 25, 1), where 1 is the number of channels. A sequence has T frames, so its total shape is (T, 25, 25, 1). I want to run an LSTM on each pixel of the feature map along the time axis and output a single map of shape (25, 25, 1), like the yellow one shown. In pseudocode, the process is:

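```python
pixels_features = []
for each_row in range(25):
    for each_col in range(25):
        # nn.LSTM expects (T, batch, input_size), so reshape to (T, 1, 1)
        seq = pixels[:, each_row, each_col, 0].reshape(-1, 1, 1)
        _, (hn, cn) = lstm(seq, (h0, c0))
        pixels_features.append(hn[-1])  # final hidden state, shape (1, hidden)
output_map = torch.stack(pixels_features).view(25, 25, 1)
```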

This works in PyTorch, but it is very slow because the 625 pixels are processed one after another. Is there a parallel solution for this problem? I'm not sure I have described the problem well, so please let me know if anything is unclear.
Thanks for any suggestions :slight_smile:
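One option, assuming a single LSTM with shared weights is applied to every pixel (as the pseudocode implies): the recurrence still has to step through time, but the 625 (25 × 25) pixel sequences are independent of each other, so they can be folded into the batch dimension and handled by one `nn.LSTM` call instead of 625. A minimal sketch; the helper name `per_pixel_lstm`, the sequence length and the random data are hypothetical:

```python
import torch
import torch.nn as nn


def per_pixel_lstm(seq: torch.Tensor, lstm: nn.LSTM) -> torch.Tensor:
    """Run one shared LSTM over every pixel's time series in a single call.

    seq:     (T, H, W, C) sequence of feature maps.
    returns: (H, W, hidden_size) map of each pixel's final hidden state.
    """
    T, H, W, C = seq.shape
    # Fold the H*W pixels into the batch dimension: (T, H*W, C).
    # nn.LSTM (batch_first=False) expects (seq_len, batch, input_size).
    x = seq.reshape(T, H * W, C)
    # Omitting (h0, c0) makes nn.LSTM start from zeros.
    _, (hn, _) = lstm(x)
    # hn: (num_layers, H*W, hidden_size); keep the last layer, restore H x W.
    return hn[-1].reshape(H, W, -1)


torch.manual_seed(0)
T = 8  # hypothetical sequence length
seq = torch.randn(T, 25, 25, 1)
lstm = nn.LSTM(input_size=1, hidden_size=1)

with torch.no_grad():
    fast = per_pixel_lstm(seq, lstm)
    # Reference: the original loop body for a single pixel (row 3, col 7).
    _, (hn, _) = lstm(seq[:, 3, 7, :].unsqueeze(1))
    ref = hn[-1, 0]

print(fast.shape)  # torch.Size([25, 25, 1])
print(torch.allclose(fast[3, 7], ref, atol=1e-5))
```

Each time step then becomes one matrix multiplication over all 625 pixels, which is typically much faster (especially on a GPU) than 625 separate calls. With a batch of B sequences, the same reshape works with B*H*W as the batch size.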

The recurrence is inherently serial along the time axis if you use an LSTM/GRU.
An alternative that is claimed to be much faster is the SRU (Simple Recurrent Unit).

Training RNNs as Fast as CNNs (
PyTorch code:
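As a rough illustration of where the SRU speed-up comes from, here is a simplified single-layer cell following the recurrence described in that paper, as I understand it: every matrix multiplication depends only on x_t, so it can be done for all timesteps at once, and only cheap element-wise operations depend on c_{t-1}. This is a hypothetical sketch, not the official implementation or its API (`MinimalSRU` is an illustrative name); it assumes input size equals hidden size so the highway term can reuse x_t, and it uses one `nn.Linear` with bias for all three projections as a simplification.

```python
import torch
import torch.nn as nn


class MinimalSRU(nn.Module):
    """Simplified single-layer SRU (hypothetical sketch, not the official code).

    Assumes input_size == hidden_size so the highway term can reuse x_t.
    """

    def __init__(self, size: int):
        super().__init__()
        # One matmul yields x~_t plus forget- and reset-gate pre-activations.
        self.proj = nn.Linear(size, 3 * size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (T, batch, size)
        # Parallel part: all timesteps projected at once, no recurrence.
        x_tilde, f, r = self.proj(x).chunk(3, dim=-1)
        f, r = torch.sigmoid(f), torch.sigmoid(r)
        # Serial part: only element-wise ops depend on the previous state.
        c = torch.zeros_like(x[0])
        hs = []
        for t in range(x.shape[0]):
            c = f[t] * c + (1 - f[t]) * x_tilde[t]
            h = r[t] * torch.tanh(c) + (1 - r[t]) * x[t]
            hs.append(h)
        return torch.stack(hs)  # (T, batch, size)


torch.manual_seed(0)
sru = MinimalSRU(size=1)
x = torch.randn(8, 625, 1)  # e.g. T=8 steps, 625 pixels folded into the batch
y = sru(x)
print(y.shape)  # torch.Size([8, 625, 1])
```

The official implementation additionally fuses the element-wise loop into a custom kernel, which this sketch does not attempt.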

Thanks for the quick reply. I think the forward pass could be parallelized, but the backward pass that updates the LSTM parameters is hard to parallelize. Is that right?

I think so.
Truncated backpropagation through time is used as a remedy.
I do not know of a better solution.
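A minimal sketch of truncated backpropagation through time with `nn.LSTM`, assuming a long sequence with per-timestep targets (both hypothetical, random data here). The sequence is processed in windows of `k` steps and the hidden state is detached at each window boundary, so gradients flow back at most `k` steps, while the forward pass still carries the state across the whole sequence. This is the simple variant where the update interval and the truncation length are both `k`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, k = 100, 20                # hypothetical sequence length and truncation window
x = torch.randn(T, 625, 1)    # 625 pixel sequences folded into the batch
y = torch.randn(T, 625, 1)    # hypothetical per-timestep targets

lstm = nn.LSTM(input_size=1, hidden_size=1)
opt = torch.optim.SGD(lstm.parameters(), lr=0.01)

state = None  # nn.LSTM starts from zero (h0, c0) when given None
losses = []
for start in range(0, T, k):
    chunk, target = x[start:start + k], y[start:start + k]
    if state is not None:
        # Cut the autograd graph: backward stops at this window boundary.
        state = tuple(s.detach() for s in state)
    out, state = lstm(chunk, state)  # out: (k, 625, 1)
    loss = nn.functional.mse_loss(out, target)
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(len(losses))  # 5
```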

Thanks for your suggestions, I will give SRU a try :)