I need some help. I’m having some problems setting up a basic LSTM autoencoder (without attention or anything fancy).
My implementation follows a paper that uses this design:
Encoder: standard LSTM layer. The input sequence is encoded in the final hidden state.

Decoder: reconstructs the sequence one element at a time, starting with the last element `x[N]`.

The decoder algorithm is as follows for a sequence of length `N`:

- Get the decoder's initial hidden state `hs[N]`: just use the encoder's final hidden state.
- Reconstruct the last element in the sequence: `x[N] = w.dot(hs[N]) + b`.
- Same pattern for the other elements: `x[i] = w.dot(hs[i]) + b`.
- Use `x[i]` and `hs[i]` as inputs to the LSTM cell to get `x[i-1]` and `hs[i-1]`.
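For concreteness, here is my reading of that recursion as a standalone sketch, with made-up sizes (`batch=2`, `seq_len=5`, `n_features=3`, `emb_size=8`). Note that in this sketch the cell's `input_size` is `n_features`, which is not what my code below does:

```python
import torch
import torch.nn as nn

# Made-up sizes, just for illustration
batch, seq_len, n_features, emb_size = 2, 5, 3, 8

cell = nn.LSTMCell(n_features, emb_size)   # consumes x[i], carries hs[i]
dense = nn.Linear(emb_size, n_features)    # x[i] = w.dot(hs[i]) + b

# Stand-ins for the encoder's final hidden/cell state
hs = torch.zeros(batch, emb_size)
cs = torch.zeros(batch, emb_size)

x_i = dense(hs)                            # reconstruct x[N]
outputs = [x_i]
for _ in range(seq_len - 1):
    hs, cs = cell(x_i, (hs, cs))           # feed x[i] and hs[i] back in
    x_i = dense(hs)                        # reconstruct x[i-1]
    outputs.append(x_i)

# Reverse: we reconstructed last-to-first
x_hat = torch.stack(outputs[::-1], dim=1)  # (batch, seq_len, n_features)
print(x_hat.shape)
```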
The problem is that at each time step, `x[i]` and `hs[i]` must be the same shape in order to be fed to the LSTM cell, but they are not: `x` has the shape of the original sequence data, while `hs` has the shape of the embedded/encoded vector. There’s no way to make them the same… What gives?
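A minimal repro of the mismatch, with made-up sizes (`emb_size=8`, `n_features=3`):

```python
import torch
import torch.nn as nn

emb_size, n_features, batch = 8, 3, 4   # made-up sizes

cell = nn.LSTMCell(emb_size, emb_size)  # cell expects inputs of width emb_size
dense = nn.Linear(emb_size, n_features)

hs = torch.zeros(batch, emb_size)
cs = torch.zeros(batch, emb_size)
x_i = dense(hs)                         # (batch, n_features) -- width 3, not 8

try:
    cell(x_i, (hs, cs))                 # blows up: 3 != 8
except RuntimeError as e:
    print("mismatch:", e)
```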
Encoder:
```python
class SeqEncoderLSTM(nn.Module):
    def __init__(self, n_features, hidden_size):
        super(SeqEncoderLSTM, self).__init__()
        self.lstm = nn.LSTM(
            n_features,
            hidden_size,
            batch_first=True)

    def forward(self, x):
        # discard the per-step outputs; keep the (h_n, c_n) tuple
        _, hs = self.lstm(x)
        return hs
```
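The encoder itself behaves as expected; with made-up sizes (`n_features=3`, `hidden_size=8`), the final hidden and cell states come back shaped `(num_layers, batch, hidden_size)`:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(3, 8, batch_first=True)  # n_features=3, hidden_size=8 (made up)
x = torch.randn(2, 5, 3)                # (batch=2, seq_len=5, n_features=3)
_, (h_n, c_n) = lstm(x)
print(h_n.shape, c_n.shape)             # torch.Size([1, 2, 8]) each
```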
Decoder:
```python
class SeqDecoderLSTM(nn.Module):
    def __init__(self, enc_hs, emb_size, n_features, seq_len):
        super(SeqDecoderLSTM, self).__init__()
        self.seq_len = seq_len
        self.cell = nn.LSTMCell(emb_size, emb_size)
        self.dense = nn.Linear(emb_size, n_features)

    def forward(self, x, hs_0):
        # collect each point of the sequence as it's reconstructed
        x = torch.tensor([])

        # initial hidden state of decoder is final hidden state of encoder
        hs_i, cs_i = hs_0

        # reconstruct first (last) element
        x_i = self.dense(hs_i)
        x = torch.cat([x, x_i])

        # reconstruct remaining elements
        for i in range(1, self.seq_len):
            # this is where it dies: x_i is n_features wide,
            # but the cell expects emb_size
            hs_i, cs_i = self.cell(x_i, (hs_i, cs_i))
            x_i = self.dense(hs_i)
            x = torch.cat([x, x_i])
        return x
```
Bringing it all together:
```python
class LSTMEncoderDecoder(nn.Module):
    def __init__(self, n_features, emb_size, seq_len):
        super(LSTMEncoderDecoder, self).__init__()
        self.n_features = n_features
        self.hidden_size = emb_size
        self.encoder = SeqEncoderLSTM(n_features, emb_size)
        self.decoder = SeqDecoderLSTM(emb_size, emb_size, n_features, seq_len)

    def forward(self, x):
        hs = self.encoder(x)
        # drop the num_layers dimension before handing off to the LSTMCell
        hs = tuple([h.squeeze(0) for h in hs])
        out = self.decoder(x, hs)
        return out
```