How can I replace FFN with LSTM?

Inside the Transformer encoder, there’s an FFN that I want to replace with an LSTM to check for performance improvements. How should I go about it?

The transformer inputs are of the form torch.Size([16, 12, 170, 152]) → (batch, timestamps, sensors, embedding_dim), but an LSTM only works on 3D data, so how should I proceed?

Should I reshape the 4D data to 3D for the LSTM by merging dimensions 0 and 1, giving (batch * timestamps, sensors, embedding_dim), and then reshape back to 4D for the rest of the model? Or is there a PyTorch equivalent of keras.layers.TimeDistributed, or should I try something else?
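
For reference, a minimal sketch of the reshape idea I have in mind (shapes taken from above; I'm assuming the LSTM hidden size equals the embedding dim so the output shape is preserved — that's just a placeholder choice, not the final design):

```python
import torch
import torch.nn as nn

# Sketch: fold (batch, timestamps) together so the LSTM sees 3D input,
# run it over the sensor axis, then restore the original 4D shape.
batch, timestamps, sensors, embed_dim = 16, 12, 170, 152
x = torch.randn(batch, timestamps, sensors, embed_dim)

lstm = nn.LSTM(input_size=embed_dim, hidden_size=embed_dim, batch_first=True)

x3d = x.reshape(batch * timestamps, sensors, embed_dim)   # (B*T, sensors, E)
out, _ = lstm(x3d)                                         # (B*T, sensors, E)
out = out.reshape(batch, timestamps, sensors, embed_dim)   # back to 4D
print(out.shape)  # torch.Size([16, 12, 170, 152])
```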

Hm, the expected input for the transformer is typically 3D: (batch_size, seq_len, embed_size), which is also the input and output shape of the FFN layer. So this would be perfectly valid input for an nn.LSTM with batch_first=True.
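
As a rough illustration (just a sketch, not taken from your model): an LSTM block that keeps the FFN's (batch_size, seq_len, embed_size) input/output contract could look like this, where LSTMFeedForward and the sizes are placeholders:

```python
import torch
import torch.nn as nn

class LSTMFeedForward(nn.Module):
    """Stand-in for a position-wise FFN: same 3D input/output shape."""
    def __init__(self, embed_size, hidden_size=None):
        super().__init__()
        hidden_size = hidden_size or embed_size
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.proj = nn.Linear(hidden_size, embed_size)  # map back to embed_size

    def forward(self, x):                # x: (batch_size, seq_len, embed_size)
        out, _ = self.lstm(x)            # (batch_size, seq_len, hidden_size)
        return self.proj(out)            # (batch_size, seq_len, embed_size)

x = torch.randn(16, 170, 152)            # (batch_size, seq_len, embed_size)
ffn_replacement = LSTMFeedForward(embed_size=152)
print(ffn_replacement(x).shape)          # torch.Size([16, 170, 152])
```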

I’m not sure how you put 4D data into a transformer.

That being said, one of the motivations behind the transformer architecture was to get rid of recurrence to allow better parallelization. So it’s not obvious why you’d want to bring it back :).


In my case, the seq_len is effectively 2D since it’s a spatio-temporal dataset: the sensors are located at different spots and each collects data at periodic intervals, so that adds another dimension, making the input 4D.

The primary motivation was to understand the impact of the FFN layer on the overall accuracy, and then to see the effect of replacing it with other neural architectures like LSTM and GRU.

So can you already train a basic transformer with this input without any changes to the FFN layer? And if so, how?

Yeah, the basic transformer is working and producing results. Here’s the link to the model code: https://github.com/XDZhelheim/STAEformer/blob/main/model/STAEformer.py

Ah, OK… you already have your own custom transformer implementation. To be honest, I’m not really sure how transformer semantics transfer from 3D to 4D input data.

I think for most spatio-temporal tasks or video data, the transformer would have to deal with 4D input. Would you suggest converting the 4D input to 3D for transformers?