Bidirectional encoder and decoder with 2 layers in Seq2Seq model

Assume the output of my encoder is (2, 64, 256): 2 LSTM layers with a hidden size of 128, with both directions concatenated. My question is: say I initialize the decoder LSTM with the same setup as the encoder, i.e. 2 layers, bidirectional, and a hidden_dim of 256. The decoder now expects a (4, 64, 256) tensor as its initial hidden state. How should I build that tensor? Should I just keep the layers and directions separate and pass the encoder state as is, or is there a correct way to do this? Each row in the encoder hidden tensor of shape (2, 64, 256) is the final hidden state of one encoder layer. Doesn't the decoder only expect the last hidden state of the encoder, which would mean I should pass only the (-1, 64, 256) slice from the encoder and initialize the rest with zeros?
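
For concreteness, here is a minimal sketch of the two options I'm describing, written against the raw encoder h_n of shape (4, 64, 128) before the directions are concatenated (the sizes here are just my example numbers):

import torch

num_layers, num_directions, batch_size, hidden_size = 2, 2, 64, 128

# Raw final hidden state of the encoder:
# (num_layers * num_directions, batch_size, hidden_size) = (4, 64, 128)
encoder_h_n = torch.randn(num_layers * num_directions, batch_size, hidden_size)

# Option (a): keep layers and directions separate and pass h_n to the decoder as is
decoder_h_0_a = encoder_h_n

# Option (b): keep only the last layer's states (both directions) and zero-initialize the rest
decoder_h_0_b = torch.zeros_like(encoder_h_n)
decoder_h_0_b[-num_directions:] = encoder_h_n[-num_directions:]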

See if this code solves your problem:

import torch
import torch.nn as nn

encoder_hidden_state = torch.randn(4, 64, 128)  # (num_layers * num_directions, batch_size, hidden_size)
encoder_cell_state = torch.randn(4, 64, 128)

num_layers = 2
num_directions = 2
hidden_size = 128
batch_size = 64

decoder_initial_h_0 = encoder_hidden_state  # Shape should be (num_layers * num_directions, batch_size, hidden_size)
decoder_initial_c_0 = encoder_cell_state    # Same shape as h_0

# Bidirectional decoder: h_0/c_0 must have shape (num_layers * num_directions, batch_size, hidden_size) = (4, 64, 128)
decoder_lstm = nn.LSTM(input_size=256, hidden_size=hidden_size, num_layers=num_layers, bidirectional=True)

decoder_input = torch.randn(10, batch_size, 256)  # (seq_len, batch_size, input_size), assuming 10 time steps

decoder_output, (h_n, c_n) = decoder_lstm(decoder_input, (decoder_initial_h_0, decoder_initial_c_0))

print(f"Decoder output shape: {decoder_output.shape}")
print(f"Decoder hidden state shape: {h_n.shape}")
print(f"Decoder cell state shape: {c_n.shape}")

Thanks for the reply, farshid. The way you described it, the encoder's hidden state is passed to the decoder as is. I was wondering whether that is how it should be done. I feel like passing the encoder's final forward hidden state to the decoder's forward direction, and the backward state to the backward direction, is not the correct way, but I might be wrong. I wanted to know if there is a correct way to do this.
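
To make concrete what I mean by "forward to forward and backward to backward", here is a small sketch of how PyTorch lays out the state (using the sizes from your example):

import torch

num_layers, num_directions, batch_size, hidden_size = 2, 2, 64, 128

encoder_h_n = torch.randn(num_layers * num_directions, batch_size, hidden_size)

# PyTorch orders the rows as (layer, direction):
# row 0 = layer 0 forward, row 1 = layer 0 backward,
# row 2 = layer 1 forward, row 3 = layer 1 backward
per_layer = encoder_h_n.view(num_layers, num_directions, batch_size, hidden_size)

# So passing encoder_h_n to the decoder unchanged means each decoder layer/direction
# is initialized from the matching encoder layer/direction, e.g. the decoder's
# layer-1 backward state starts from per_layer[1, 1].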