How to extract all hidden state outputs instead of only the last hidden state output in an LSTM?

I am new to PyTorch. I am trying to grasp how the input and output shapes work in an encoder-decoder model. While implementing an attention model, I am stuck at the point where I need to get all of the hidden state outputs.

My code for the decoder:
import torch
import torch.nn as nn

class DecoderLSTM(nn.Module):

    def __init__(self, embed_size, vocab_size, hidden_size):
        super(DecoderLSTM, self).__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.word_embeddings = nn.Embedding(vocab_size, embed_size)
        self.decoderlstm = nn.LSTM(input_size=embed_size,
                                   hidden_size=hidden_size,
                                   num_layers=1,
                                   batch_first=True,
                                   bias=True)
        self.fc_layer = nn.Linear(hidden_size, vocab_size)

        self.attention = BahdanauAttention(hidden_size)

    def init_hidden(self, enc_hidden):
        enc_h_t, enc_c_t = enc_hidden
        h_t = torch.tensor(enc_h_t).requires_grad_()
        c_t = torch.tensor(enc_c_t).requires_grad_()
        return (h_t, c_t)

    def forward(self, enc_out, enc_hidden, captions):

        (h_t, c_t) = self.init_hidden(enc_hidden)

        embeds = self.word_embeddings(captions)
        inputs = torch.tensor(embeds)
        print('Embed Tensor Shape:', inputs.shape)
        dec_out, (dec_h_t, dec_c_t) = self.decoderlstm(inputs, (h_t.detach(), c_t.detach()))
        print('dec_h_t :', dec_h_t.shape)
        print('dec_c_t :', dec_c_t.shape)
        attention_out = self.attention(enc_out, dec_out)
        outputs = self.fc_layer(attention_out)
        return outputs

Instead of getting all of the hidden state outputs, I am getting only the last state output with shape (num_layers, batch_size, hidden_size).

How do I get all of the hidden state outputs?

You don’t get the intermediate cell states (h_t, c_t) back from the LSTM, only the final ones.
Thus, if you need them, you would want to loop over t yourself, using either an LSTM with a sequence length of 1 or an LSTMCell.
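(Note that the output tensor, dec_out above, already contains the top layer’s hidden state for every timestep; the manual loop is mainly for when you also want the per-step cell states.) A minimal sketch of the LSTMCell option, with sizes made up purely for illustration:

import torch
import torch.nn as nn

# Toy sizes, for illustration only
batch_size, seq_len, embed_size, hidden_size = 4, 7, 32, 64

cell = nn.LSTMCell(input_size=embed_size, hidden_size=hidden_size)

inputs = torch.randn(batch_size, seq_len, embed_size)   # (batch, time, features)
h_t = torch.zeros(batch_size, hidden_size)
c_t = torch.zeros(batch_size, hidden_size)

all_h, all_c = [], []
for t in range(seq_len):
    # Feed one timestep at a time so every intermediate state can be kept
    h_t, c_t = cell(inputs[:, t, :], (h_t, c_t))
    all_h.append(h_t)
    all_c.append(c_t)

all_h = torch.stack(all_h, dim=1)   # (batch, time, hidden) -- all hidden states
all_c = torch.stack(all_c, dim=1)   # (batch, time, hidden) -- all cell states

Stacking the lists gives you the full sequences of states, which you can then feed into the attention module.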

Best regards

Thomas


You can try running all of your hidden layers explicitly in your forward function. Here is the pseudo-code:
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))

Then apply a for loop over the layers, iterate up to the number of layers you have, and keep (or display) the output of every layer, as in the sketch below.
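A runnable sketch of that idea (the layer names and sizes here are made up for illustration, not taken from the model in the question):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Illustrative layer sizes only
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

    def forward(self, x):
        hidden_outputs = []                 # keep every intermediate activation
        for layer in [self.conv1, self.conv2, self.conv3]:
            x = F.relu(layer(x))
            hidden_outputs.append(x)
        return x, hidden_outputs

net = Net()
out, hiddens = net(torch.randn(1, 3, 28, 28))
for i, h in enumerate(hiddens):
    print(f'layer {i}: {h.shape}')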
