I am trying to create a simple LSTM autoencoder.
More precisely, I want to take a sequence of vectors, each of size input_dim, and produce an embedded representation of size latent_dim via an LSTM. I would then like to decode this embedded representation via another LSTM, (hopefully) reproducing the input sequence of vectors.
Here are my definitions for the encoder and decoder:
self.encoder = nn.LSTM(input_dim, latent_dim)
# initial (h_0, c_0) for the encoder; each has shape (num_layers, batch, hidden_size) = (1, 1, latent_dim)
self.encoder_hidden = (autograd.Variable(torch.zeros(1, 1, self.latent_dim)),
                       autograd.Variable(torch.zeros(1, 1, self.latent_dim)))

self.decoder = nn.LSTM(latent_dim, input_dim)
# initial (h_0, c_0) for the decoder; each has shape (1, 1, input_dim)
self.decoder_hidden = (autograd.Variable(torch.zeros(1, 1, self.input_dim)),
                       autograd.Variable(torch.zeros(1, 1, self.input_dim)))
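For context, these lines live in the __init__ of an nn.Module subclass; here is a minimal sketch of the surrounding boilerplate (the class name is made up, just for illustration):

import torch
import torch.nn as nn
from torch import autograd

class LSTMAutoencoder(nn.Module):  # name is only for illustration
    def __init__(self, input_dim, latent_dim):
        super(LSTMAutoencoder, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        # ... the encoder/decoder definitions above go here ...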
The encoding step seems to make sense: for each vector in the input sequence, compute the output and hidden state, passing the hidden state along to the next call to the LSTM.
Here is the encode step:
def encode(self, word_vectors):
    out = None
    # feed the sequence one vector at a time, threading the hidden state through
    for word_vec in word_vectors:
        out, self.encoder_hidden = self.encoder(word_vec.view(1, 1, -1),
                                                self.encoder_hidden)
    return out
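As far as I can tell this behaves as expected: on a toy sequence it returns the output of the last timestep, of size latent_dim. A rough usage sketch (the sizes are arbitrary examples, and LSTMAutoencoder is the illustrative class from above):

model = LSTMAutoencoder(input_dim=50, latent_dim=10)
word_vectors = [autograd.Variable(torch.zeros(50)) for _ in range(7)]  # toy input sequence
encoded = model.encode(word_vectors)
print(encoded.size())  # (1, 1, 10), i.e. (1, 1, latent_dim)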
The decoding step, however, I do not understand. I have seen references in various posts/docs saying that the last hidden state of the encoding LSTM should be the first hidden state of the decoding LSTM. However, this cannot be the case here, as the hidden states are of different dimensions (the size of the hidden and cell state must be the same as the output size for an LSTM, it seems).
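To make the mismatch concrete, here is a small standalone sketch with arbitrary example sizes (input_dim=50, latent_dim=10):

import torch
import torch.nn as nn
from torch import autograd

input_dim, latent_dim = 50, 10
encoder = nn.LSTM(input_dim, latent_dim)
decoder = nn.LSTM(latent_dim, input_dim)

seq = autograd.Variable(torch.zeros(7, 1, input_dim))  # (seq_len, batch, input_dim)
enc_out, (h_n, c_n) = encoder(seq)
print(h_n.size())  # (1, 1, 10), i.e. latent_dim
# the decoder's hidden/cell states must have size (1, 1, 50), i.e. input_dim,
# so (h_n, c_n) cannot be used as its initial hidden state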
Therefore I have been trying to use the output of the encoding LSTM as the first input to the decoding LSTM. This works for the first step, but past the first iteration there is nothing of size latent_dim left to feed in, since the decoder's outputs have size input_dim.
Here is my decode step:
def decode(self, encoded, target_length):
    outputs = []
    out = encoded
    for i in range(target_length):
        # feed the previous output back in as the next input
        out, self.decoder_hidden = self.decoder(out.view(1, 1, -1),
                                                self.decoder_hidden)
        outputs.append(out)
    return outputs
This will fail at i=1, because out is of the wrong dimension: after the first pass through the decoder it has size input_dim rather than latent_dim.
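Concretely, with the same example sizes and imports as in the sketch above:

input_dim, latent_dim = 50, 10  # arbitrary example sizes
decoder = nn.LSTM(latent_dim, input_dim)
hidden = (autograd.Variable(torch.zeros(1, 1, input_dim)),
          autograd.Variable(torch.zeros(1, 1, input_dim)))
encoded = autograd.Variable(torch.zeros(1, 1, latent_dim))  # stand-in for what encode() returns
out, hidden = decoder(encoded, hidden)  # i=0 is fine: out has size (1, 1, 50)
out, hidden = decoder(out, hidden)      # i=1 fails: the decoder expects inputs of size latent_dim = 10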
Does anybody have any ideas for how I could go about this?