I am learning LSTM and GRU, but their outputs are confusing to me. Does their hidden mean the same thing? What is the cell state of LSTM? On the internet, cell state is said that there are very few changes, but when I search for the reason for the change, I cannot find the answer. When will the cell state change? I am writing code for an LSTM seq2seq model, and its encoder layer is like this. It does not use cell state and hidden at all, but it seems to be running normally. I always feel that it is incorrect. Can someone help me point it out?
mycode:
import torch
import torch.nn as nn
import random
class Encoder(nn.Module):
def __init__(self,input_dim,emb_dim,hid_dim,n_layers,dropout):
super(Encoder,self).__init__()
self.hid_dim = hid_dim
self.n_layers = n_layers
self.embeddings = nn.Embedding(input_dim,emb_dim)
self.rnn = nn.LSTM(emb_dim,hid_dim,n_layers,dropout=dropout)
self.dropout = nn.Dropout(dropout)
def forward(self,src,src_length):
# src : (Ts,bcsz)
# src_length (bcsz)
#embedded:(Ts,csz,f)
embedded = self.dropout(self.embeddings(src))
# outputs:(Ts,bcsz,f)
# hidden:(nlayers,bcsz,f)
# cell: (nlayerss,bcsz,f)
# assert outputs[-1] == hidden[-1]
outputs,(hidden,cell) = self.rnn(embedded)
# return (hidden,cell)
# (bcsz,f)
higher_hidden_accurate = torch.stack([
outputs[t,i_batch] for i_batch ,t in enumerate(src_length-1)
],dim=0)
hidden_out = torch.zeros_like(hidden)
cell_out = torch.zeros_like(cell)
hidden_out[0] = higher_hidden_accurate
return (hidden_out,cell_out)
class Decoder(nn.Module):
def __init__(self,output_dim,emb_dim,hid_dim,n_layers,dropout):
super(Decoder,self).__init__()
self.output_dim = output_dim
self.hid_dim = hid_dim
self.n_layers = n_layers
self.embeddings = nn.Embedding(output_dim,emb_dim)
self.rnn = nn.LSTM(emb_dim,hid_dim,n_layers,dropout=dropout)
self.fc_out = nn.Linear(hid_dim,output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self,input,hidden,cell):
# input : (bcsz)
# hidden (nlayers,bcsz,f)
# cell (nlayers,bcsz,f)
# (1,bcsz,f)
input = input.unsqueeze(0)
embedded = self.dropout(self.embeddings(input))
#ouput:(1,bcsz,f)
# hidden:(nlayers,bcsz,f)
# cell:(nlayers,bcsz,f)
output,(hidden,cell) = self.rnn(embedded,(hidden,cell))
#prediction :(bcsz,output_dim)
prediction = self.fc_out(output.squeeze(0))
return prediction,hidden,cell
class Seq2Seq(nn.Module):
def __init__(self, encoder, decoder,sos_value,eos_value) :
super(Seq2Seq,self).__init__()
self.encoder = encoder
self.decoder = decoder
self.sos_value = sos_value
self.eos_value = eos_value
assert encoder.hid_dim == decoder.hid_dim
assert encoder.n_layers == decoder.n_layers
def forward(self,src,src_len,target):
# src (Ts,bcsz)
# src_len( bcsz,)
# target (Tt,bcsz)
_,bcsz = src.shape
Tt,_ = target.shape
target_vocab_size = self.decoder.output_dim
# tensor to store decoder outputs
outputs = torch.zeros(Tt,bcsz,target_vocab_size).to(target.device)
hidden,cell = self.encoder(src,src_len)
input = target[0,:]
for t in range(1,Tt):
output,hidden,cell = self.decoder(input,hidden,cell)
outputs[t] = output
input = target[t]
return outputs
def decode(self,src,src_length,max_devode_step= 20):
# src(Ts,bcsz)
# src_length(bcsz)
_,bcsz = src.shape
hidden,cell = self.encoder(src,src_length)
input = torch.full((bcsz,),fill_value=self.sos_value,dtype = torch.long).to(src.device)
history = []
eos_count = torch.zeros((bcsz,)).to(src.device)
for _ in range(max_devode_step):
# output:(bcsz,C)
output,hidden ,cell = self.decoder(input,hidden,cell)
# top1:(bcsz,)
top1 = output.argmax(1)
eos_count += (top1 == self.eos_value)
if(eos_count != 0 ).all():break
history.append(top1)
input = top1.long()
if len(history) == 0:
return torch.full((bcsz,1),self.eos_value,dtype=torch.long).to(src.device)
# (bcsz,T)
return torch.stack(history,dim=-1)