I’m trying to build an LSTM AE with multiple layers and bidirectional=True, but I’m getting confused with the dimensions. A single-layer AE without bidirection works, but if I add layers or set bidirectional=True I have to recalculate the dimensions for each cell. Is there a straightforward approach, or how do you calculate the input/output dimensions of each cell?
I don’t quite understand your question about recalculating the dimensions of each cell.
Have a look at my code of an LSTM/GRU-based autoencoder to see if that helps. The relevant snippets are probably the _flatten_hidden() method of the encoder and the _unflatten_hidden() method of the decoder.
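For reference, here is a minimal sketch of what such a flatten/unflatten pair might look like. The function names mirror the ones mentioned above, but the shapes and implementation here are assumptions, not the linked code: the idea is to move the batch dimension to the front before merging the layer/direction and hidden dimensions, so the inverse operation restores the tensor exactly.

```python
import torch

# Assumed shapes: hidden state h is (num_layers * num_directions, batch, hidden),
# which is what nn.LSTM returns for h_n.
num_layers, num_directions, batch, hidden = 2, 2, 4, 8
h = torch.randn(num_layers * num_directions, batch, hidden)

def flatten_hidden(h):
    # (L*D, B, H) -> (B, L*D*H): bring batch to the front, then merge the rest
    return h.transpose(0, 1).contiguous().view(h.size(1), -1)

def unflatten_hidden(flat, num_states, hidden):
    # (B, L*D*H) -> (L*D, B, H): the exact inverse of flatten_hidden
    return flat.view(flat.size(0), num_states, hidden).transpose(0, 1).contiguous()

flat = flatten_hidden(h)                                            # (4, 32)
restored = unflatten_hidden(flat, num_layers * num_directions, hidden)
assert torch.equal(h, restored)  # round-trip loses nothing
```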
My Enc/Dec consists of two cells with decreasing dimensions. Now I’d like to add bidirectional functionality and more than one layer for each cell → the dimensions of my input change.
Strictly speaking you don’t have LSTM cells but LSTM layers – at least I think your code was using nn.LSTM and not nn.LSTMCell.
As far as I understand, an LSTM with 2 layers is the same as having 2 separate single-layer LSTMs and using the output of the first as the input of the second. I assume there are some differences in the details, like the initialization of the hidden state, but overall it should be the same. In short, I would use only one LSTM with 2 or more layers.
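To see how the shapes behave, here is a small sketch (the sizes are made up for illustration). With `num_layers=2`, `nn.LSTM` stacks the layers internally, and `bidirectional=True` folds the direction into the first dimension of the hidden state:

```python
import torch
import torch.nn as nn

batch, seq_len, input_size, hidden_size = 4, 10, 16, 32
x = torch.randn(batch, seq_len, input_size)

# One nn.LSTM with num_layers=2 stacks two layers internally;
# the output of layer 1 is fed as the input of layer 2.
lstm = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True)
out, (h, c) = lstm(x)
print(out.shape)  # (4, 10, 32) -- outputs of the top layer only
print(h.shape)    # (2, 4, 32)  -- one last hidden state per layer

# With bidirectional=True the direction dimension folds into dim 0 of h,
# and forward/backward outputs are concatenated on the feature dimension.
bilstm = nn.LSTM(input_size, hidden_size, num_layers=2,
                 batch_first=True, bidirectional=True)
out, (h, c) = bilstm(x)
print(out.shape)  # (4, 10, 64) -- 2 * hidden_size
print(h.shape)    # (4, 4, 32)  -- num_layers * num_directions = 4
```

So the only formula you really need is: `h` has shape `(num_layers * num_directions, batch, hidden_size)` and `out` has shape `(batch, seq_len, num_directions * hidden_size)` with `batch_first=True`.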
Using view() or reshape() is not intrinsically wrong, one just has to be careful to do it right. For example, if you look at the _flatten() method of the decoder:
I wouldn’t get an error since the shape is still correct. However, the data is now kind of scrambled; see my example in that old post.
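A toy example of that scrambling (the numbers here are just tags so you can track which sample each value belongs to): calling view() on the hidden state without first moving the batch dimension to the front produces a tensor with the right shape but with values from different samples mixed into the same row.

```python
import torch

# (num_states=3, batch=2, hidden=2); each value tagged by position
h = torch.arange(12).view(3, 2, 2)

# Wrong: merging dims directly -- shape is fine, but row 0 now mixes
# values that belong to both samples in the batch.
wrong = h.view(2, -1)
print(wrong)   # [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]

# Right: bring the batch dimension to the front before flattening,
# so each row contains only one sample's values.
right = h.transpose(0, 1).contiguous().view(2, -1)
print(right)   # [[0, 1, 4, 5, 8, 9], [2, 3, 6, 7, 10, 11]]
```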
Technically, you only need to flatten the tensor (e.g., the last hidden state of the LSTM) if you intend to push it through some additional linear layers before giving it to the decoder. Without that – and assuming your encoder and decoder LSTM have the same setup (same number of layers, same hidden dimension, both uni- or both bi-directional) – you can simply set the initial hidden state of the decoder to the last hidden state of the encoder.
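A minimal sketch of that handover, assuming a matching encoder/decoder setup (the sizes and the teacher-forcing-style decoder input are illustrative choices, not prescribed by the discussion above):

```python
import torch
import torch.nn as nn

# Encoder and decoder share layers, hidden size, and directionality,
# so the hidden state tensors are shape-compatible with no flattening.
input_size, hidden_size, num_layers = 8, 16, 2
encoder = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
decoder = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

x = torch.randn(4, 10, input_size)   # (batch, seq_len, features)
_, (h, c) = encoder(x)               # h, c: (num_layers, batch, hidden)

# Hand the encoder's last state straight to the decoder as its
# initial state; here the decoder input is just x (teacher-forcing style).
dec_out, _ = decoder(x, (h, c))
print(dec_out.shape)  # (4, 10, 16)
```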