How should I understand the output of an LSTM?

I am learning LSTM and GRU, but their outputs confuse me. Does "hidden" mean the same thing in both? What is the cell state of an LSTM? Online, the cell state is said to change very little, but when I search for why and when it changes, I cannot find an answer. I am writing an LSTM seq2seq model, and its code (encoder included) is below. The encoder does not use the returned cell state and hidden state at all, yet the model seems to run normally. I keep feeling that something is wrong. Can someone point it out?

import torch 
import torch.nn as nn
import random

class Encoder(nn.Module):
    def __init__(self,input_dim,emb_dim,hid_dim,n_layers,dropout):
        super().__init__()  # required before registering submodules
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embeddings = nn.Embedding(input_dim,emb_dim)

        self.rnn = nn.LSTM(emb_dim,hid_dim,n_layers,dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self,src,src_length):
        # src : (Ts,bcsz)
        # src_length (bcsz)


        embedded = self.dropout(self.embeddings(src))
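        # outputs:(Ts,bcsz,f)  top-layer hidden state at every step, padding included
        # hidden:(nlayers,bcsz,f)  taken at step Ts-1, i.e. after the padding of shorter sequences
        # cell: (nlayers,bcsz,f)
        # outputs[-1] == hidden[-1] only holds when no sequence in the batch is padded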
        outputs,(hidden,cell) = self.rnn(embedded)

        # return (hidden,cell)

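        # (bcsz,f): top-layer output at each sequence's last real (non-pad) step
        higher_hidden_accurate = torch.stack([
            outputs[t,i_batch] for i_batch, t in enumerate(src_length-1)
        ])
        # only layer 0 of the decoder's initial hidden state receives this vector;
        # the other layers and the whole cell state start from zeros
        hidden_out = torch.zeros_like(hidden)
        cell_out = torch.zeros_like(cell)

        hidden_out[0] = higher_hidden_accurate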
        return (hidden_out,cell_out)

class Decoder(nn.Module):
    def __init__(self,output_dim,emb_dim,hid_dim,n_layers,dropout):
        super().__init__()  # required before registering submodules
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embeddings = nn.Embedding(output_dim,emb_dim)
        self.rnn = nn.LSTM(emb_dim,hid_dim,n_layers,dropout=dropout)
        self.fc_out = nn.Linear(hid_dim,output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self,input,hidden,cell):
        # input : (bcsz)
        # hidden (nlayers,bcsz,f)
        # cell (nlayers,bcsz,f)

        # (1,bcsz,f)
        input = input.unsqueeze(0)
        embedded = self.dropout(self.embeddings(input))

        # hidden:(nlayers,bcsz,f)
        # cell:(nlayers,bcsz,f)
        output,(hidden,cell) = self.rnn(embedded,(hidden,cell))

        #prediction :(bcsz,output_dim)
        prediction = self.fc_out(output.squeeze(0))

        return prediction,hidden,cell

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder,sos_value,eos_value):
        super().__init__()  # required before registering submodules
        self.encoder = encoder
        self.decoder = decoder

        self.sos_value = sos_value
        self.eos_value = eos_value

        assert encoder.hid_dim == decoder.hid_dim
        assert encoder.n_layers == decoder.n_layers

    def forward(self,src,src_len,target):
        # src (Ts,bcsz)
        # src_len( bcsz,)
        # target (Tt,bcsz)
        _,bcsz = src.shape
        Tt,_ = target.shape

        target_vocab_size = self.decoder.output_dim

        # tensor to store decoder outputs
        outputs = torch.zeros(Tt,bcsz,target_vocab_size).to(target.device)

        hidden,cell = self.encoder(src,src_len)

        input = target[0,:]
        for t in range(1,Tt):
            output,hidden,cell = self.decoder(input,hidden,cell)
            outputs[t] = output            
            input = target[t]

        return outputs

    def decode(self,src,src_length,max_decode_step=20):
        # src(Ts,bcsz)
        # src_length(bcsz)

        _,bcsz = src.shape

        hidden,cell = self.encoder(src,src_length)
        input = torch.full((bcsz,),fill_value=self.sos_value,dtype=torch.long,device=src.device)

        history = []

        eos_count = torch.zeros((bcsz,),device=src.device)

        for _ in range(max_decode_step):
            # output:(bcsz,C)
            output,hidden,cell = self.decoder(input,hidden,cell)
            # top1:(bcsz,)
            top1 = output.argmax(1)
            # store the prediction; without this, history stays empty and decode() always returns <eos>
            history.append(top1)

            eos_count += (top1 == self.eos_value)

            # stop once every sequence in the batch has produced <eos> at least once
            if (eos_count != 0).all():
                break

            input = top1.long()

        if len(history) == 0:
            return torch.full((bcsz,1),self.eos_value,dtype=torch.long,device=src.device)
        # (bcsz,T)
        return torch.stack(history,dim=-1)

Hi Yinjun,

Does their hidden mean the same thing?

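In short, no. Since the architectures of LSTM and GRU differ, their hidden states are also computed differently. In both, the hidden state h_t is what the layer outputs at each time step, but a GRU has no separate cell state, so nn.GRU returns output, h_n while nn.LSTM returns output, (h_n, c_n).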
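As a minimal illustration of the difference in return values (the sizes here are arbitrary, chosen only for the example):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size, num_layers = 5, 3, 4, 6, 2
x = torch.randn(seq_len, batch, input_size)  # (T, batch, features), batch_first=False

lstm = nn.LSTM(input_size, hidden_size, num_layers)
gru = nn.GRU(input_size, hidden_size, num_layers)

# LSTM: per-step outputs plus a *tuple* (h_n, c_n)
lstm_out, (lstm_h, lstm_c) = lstm(x)
# GRU: per-step outputs plus h_n only, because there is no cell state
gru_out, gru_h = gru(x)

print(lstm_out.shape, lstm_h.shape, lstm_c.shape)
# torch.Size([5, 3, 6]) torch.Size([2, 3, 6]) torch.Size([2, 3, 6])
print(gru_out.shape, gru_h.shape)
# torch.Size([5, 3, 6]) torch.Size([2, 3, 6])
```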

What is the cell state of LSTM? When will the cell state change?

In theory, RNNs should be able to model long sequences by retaining context from far back in a sequence. In practice, they struggle to keep context over longer sequences, mainly because of vanishing gradients. The LSTM architecture was designed primarily to solve this problem, and the cell state is the means by which an LSTM preserves long-term memory. The cell state is controlled by gates (input, forget and output gates), and it is updated at every time step based on the current input, the previous hidden state, the gates and the model parameters. It stays nearly unchanged when the forget gate is close to 1 and little new content is written through the input gate.
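For reference, these are the per-step LSTM update equations as written in the PyTorch nn.LSTM documentation (σ is the sigmoid, ⊙ element-wise multiplication, x_t the input, h_{t-1} and c_{t-1} the previous hidden and cell states):

```latex
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

So c_t is recomputed at every time step; it stays close to c_{t-1} when f_t is near 1 and i_t ⊙ g_t is near 0, which is usually what "the cell state changes very little" refers to.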

I recommend reading through the LSTM architecture once.
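To make the architecture concrete, here is a minimal sketch of one LSTM time step written out by hand and checked against nn.LSTMCell's own result. It assumes PyTorch's documented weight layout (the four gates stacked in the order input, forget, cell, output); the helper name manual_lstm_step and all sizes are hypothetical, chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, batch = 3, 4, 2
lstm_cell = nn.LSTMCell(input_size, hidden_size)

x = torch.randn(batch, input_size)
h_prev = torch.randn(batch, hidden_size)
c_prev = torch.randn(batch, hidden_size)

def manual_lstm_step(x, h_prev, c_prev, lstm_cell):
    """One LSTM time step written out by hand, reusing the LSTMCell's weights."""
    gates = (x @ lstm_cell.weight_ih.T + lstm_cell.bias_ih
             + h_prev @ lstm_cell.weight_hh.T + lstm_cell.bias_hh)
    # The four gate pre-activations are stacked as input, forget, cell (g), output
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, g, o = torch.sigmoid(i), torch.sigmoid(f), torch.tanh(g), torch.sigmoid(o)
    c = f * c_prev + i * g   # keep part of the old memory, write new content
    h = o * torch.tanh(c)    # hidden state is a gated view of the cell state
    return h, c

h_manual, c_manual = manual_lstm_step(x, h_prev, c_prev, lstm_cell)
h_ref, c_ref = lstm_cell(x, (h_prev, c_prev))

print(torch.allclose(h_manual, h_ref, atol=1e-5), torch.allclose(c_manual, c_ref, atol=1e-5))
```

Running this step in a loop over time, feeding (h, c) back in, is all that nn.LSTM does per layer.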
From a quick read, your code does not seem to have any errors, and according to you it also runs fine. I am happy to help further with the logical correctness of the code if you can describe your modelling problem clearly in words.
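One logical point you can check yourself: the Encoder's higher_hidden_accurate gather picks the top-layer output at each sequence's last real step, and it should match h_n[-1] when the padded batch is packed first. The sketch below swaps in pack_padded_sequence (a standard PyTorch utility that the question's code does not use) and compares the two; sizes and lengths are arbitrary assumptions. With packing, h_n and c_n are already taken at each sequence's true last step for every layer, so an encoder could return (h_n, c_n) directly.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
emb_dim, hid_dim, n_layers = 4, 6, 2
lstm = nn.LSTM(emb_dim, hid_dim, n_layers)

# Padded batch of shape (Ts=5, bcsz=3, emb_dim) plus the real length of each sequence
src_length = torch.tensor([5, 3, 2])
embedded = torch.randn(5, 3, emb_dim)

# 1) Unpacked run + manual gather at each sequence's last real step (as in the Encoder)
outputs, _ = lstm(embedded)
gathered = torch.stack([outputs[t, b] for b, t in enumerate(src_length - 1)])

# 2) Packed run: the LSTM skips the padding, so h_n / c_n come from each
#    sequence's true last step, for every layer, in the original batch order
packed = pack_padded_sequence(embedded, src_length, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
padded_out, out_lengths = pad_packed_sequence(packed_out)  # back to (Ts, bcsz, hid_dim)

print(torch.allclose(gathered, h_n[-1], atol=1e-5))
```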


Thank you for your answer. I think I was confused by their similar names. Also, I still feel there is a problem with my code, because it does not really use the cell state and hidden state. That makes me uneasy, because I don't know when to use them, when to change them, or how to change them.

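Cell states are usually not used directly to compute outputs, but hidden states definitely are: in nn.LSTM, the output at each time step is the top layer's hidden state h_t, and the cell state only reaches the output through h_t = o_t * tanh(c_t), where o_t is the output gate. Again, if you describe your modelling task clearly, I'll be able to help you translate it into code.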
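A quick way to see this relationship, as a sketch with arbitrary sizes and an unpadded random batch (both assumptions for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=6, num_layers=2)
x = torch.randn(5, 3, 4)  # (T=5, batch=3, features=4), no padding

output, (h_n, c_n) = lstm(x)

# output stacks the top layer's hidden state h_t for every time step: (5, 3, 6)
# h_n and c_n hold the final hidden / cell state of every layer: (2, 3, 6)
print(output.shape, h_n.shape, c_n.shape)

# The last step of output is the top layer's final hidden state
print(torch.allclose(output[-1], h_n[-1]))
```

With a padded batch (and no packing), output[-1] and h_n are taken after the padding, which is why the question's encoder gathers by length instead.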

Actually, I'm not quite sure when I should manually change the hidden state and cell state. I wrote this code only to learn when each step is needed, but when I finished, I found that it does not use the hidden state and cell state at all. Could you give me some example code showing when to manually change the cell state and hidden state?

The official PyTorch LSTM tutorial might help.
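One common case where you handle the states yourself is feeding a single sequence to the LSTM over several calls, for example chunked long sequences or truncated backpropagation through time: pass the (h, c) returned by one call into the next, detach it to cut gradients, and reset it (here by passing None, which nn.LSTM treats as zeros) when a new, unrelated sequence starts. The encoder-to-decoder handoff in the question's Seq2Seq (decoder(input, hidden, cell)) is the same idea. A minimal sketch with arbitrary sizes and chunk length (assumptions), checking that chunked processing matches one full call:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=6, num_layers=1)
x = torch.randn(10, 2, 4)  # (T=10, batch=2, features=4)

# Reference: the whole sequence in one call (initial state defaults to zeros)
full_out, (h_full, c_full) = lstm(x)

# Same sequence in chunks of 5 steps, carrying (h, c) from one call to the next
state = None  # None = zero initial hidden and cell states (a fresh sequence)
chunk_outs = []
for chunk in x.split(5, dim=0):
    out, state = lstm(chunk, state)
    chunk_outs.append(out)
    # Truncated BPTT: keep the values but stop gradients flowing into earlier chunks
    state = tuple(s.detach() for s in state)

chunked_out = torch.cat(chunk_outs, dim=0)
print(torch.allclose(full_out, chunked_out, atol=1e-5))
print(torch.allclose(h_full, state[0], atol=1e-5), torch.allclose(c_full, state[1], atol=1e-5))
```

If the state were reset to None between the two chunks instead, the second chunk would start with no memory of the first, and the outputs would no longer match.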