LSTMCell autoencoder input shape problems

I need some help. I’m having some problems setting up a basic LSTM autoencoder (without attention or anything fancy).

This implementation follows a paper that uses the following setup:

Encoder: Standard LSTM layer. Input sequence is encoded in the final hidden state.
Decoder: Reconstruct the sequence one element at a time, starting with the last element x[N].

The decoder algorithm is as follows, for a sequence of length N:

  1. Get the decoder's initial hidden state `hs[N]`: just use the encoder's final hidden state.
  2. Reconstruct the last element in the sequence: `x[N] = w.dot(hs[N]) + b`.
  3. Same pattern for the other elements: `x[i] = w.dot(hs[i]) + b`
  4. Use `x[i]` and `hs[i]` as inputs to the LSTM cell to get `x[i-1]` and `hs[i-1]`

The problem is that at each time step, x[i] and hs[i] must be the same shape in order to be input to the LSTM cell, but they are not. That’s because x has the shape of the original sequence data, while hs has the shape of the embedded/encoded vector. There’s no way to make them the same…
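
To make the mismatch concrete, here is a stripped-down repro with made-up sizes (these numbers are not my real data, they just show the shape clash):

import torch
import torch.nn as nn

batch, n_features, emb_size = 4, 10, 32     # made-up sizes

cell = nn.LSTMCell(emb_size, emb_size)      # what my decoder currently builds
hs_i = torch.zeros(batch, emb_size)         # hidden state from the encoder
cs_i = torch.zeros(batch, emb_size)         # cell state from the encoder
x_i = torch.zeros(batch, n_features)        # one reconstructed element of the sequence

# RuntimeError: the cell expects an input of size emb_size (32), but x_i has n_features (10)
hs_i, cs_i = cell(x_i, (hs_i, cs_i))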

What gives?

Encoder:

class SeqEncoderLSTM(nn.Module):
    def __init__(self, n_features, hidden_size):
        super(SeqEncoderLSTM, self).__init__()
        
        self.lstm = nn.LSTM(
            n_features, 
            hidden_size, 
            batch_first=True)
        
    def forward(self, x):
        # discard the per-step outputs; return only the final (h_n, c_n) tuple
        _, hs = self.lstm(x)
        return hs

Decoder:

class SeqDecoderLSTM(nn.Module):
    def __init__(self, enc_hs, emb_size, n_features, seq_len):
        super(SeqDecoderLSTM, self).__init__()
        
        self.seq_len = seq_len
        
        self.cell = nn.LSTMCell(emb_size, emb_size)
        self.dense = nn.Linear(emb_size, n_features)
        
    def forward(self, x, hs_0):
        
        # add each point to the sequence as it's reconstructed
        x = torch.tensor([])
        
        # Initial hidden state of decoder is final hidden state of encoder
        hs_i, cs_i = hs_0
        
        # reconstruct first (last) element
        x_i = self.dense(hs_i)
        x = torch.cat([x, x_i])
        
        # reconstruct remaining elements
        for i in range(1, self.seq_len):
            print("last x: ", x.shape)
            print("hs shape: ", hs_i.shape)
            print("cs shape: ", cs_i.shape)
            hs_i, cs_i = self.cell(x_i, (hs_i, cs_i))
            x_i = self.dense(hs_i)
            x = torch.cat([x, x_i])

        return x

Bringing it all together:


class LSTMEncoderDecoder(nn.Module):
    def __init__(self, n_features, emb_size, seq_len):
        super(LSTMEncoderDecoder, self).__init__()
        self.n_features = n_features
        self.hidden_size = emb_size

        self.encoder = SeqEncoderLSTM(n_features, emb_size)
        self.decoder = SeqDecoderLSTM(emb_size, emb_size, n_features, seq_len)
    
    def forward(self, x):
        hs = self.encoder(x)
        hs = tuple([h.squeeze(0) for h in hs])
        out = self.decoder(x, hs)
        return out
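
For reference, this is roughly how I call it, with made-up sizes (my real data is different); the shape error surfaces inside the decoder's forward():

model = LSTMEncoderDecoder(n_features=10, emb_size=32, seq_len=20)
x = torch.randn(8, 20, 10)    # (batch, seq_len, n_features)
out = model(x)                # blows up with a shape/size mismatch in SeqDecoderLSTM.forward()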
        

I feel the issue is with

self.cell = nn.LSTMCell(emb_size, emb_size)

which should be self.cell = nn.LSTMCell(n_features, emb_size).

I assume(!), looking at

3. Same pattern for the other elements: `x[i] = w.dot(hs[i]) + b`

that self.dense in your decoder converts the first hidden state into an input. However, I can’t see what

x = torch.cat([x, x_i])

is doing.
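
In case it helps, here is a rough, untested sketch of the decoder wiring I have in mind. I'm guessing at your intended shapes, I dropped the unused x and enc_hs arguments, and I collect the reconstructed elements in a list and stack them at the end:

class SeqDecoderLSTM(nn.Module):
    def __init__(self, emb_size, n_features, seq_len):
        super(SeqDecoderLSTM, self).__init__()
        self.seq_len = seq_len
        # the cell input is the previously reconstructed element (n_features),
        # while the hidden state keeps the encoder's embedding size (emb_size)
        self.cell = nn.LSTMCell(n_features, emb_size)
        self.dense = nn.Linear(emb_size, n_features)

    def forward(self, hs_0):
        hs_i, cs_i = hs_0                   # each of shape (batch, emb_size)
        x_i = self.dense(hs_i)              # reconstruct the last element: (batch, n_features)
        outputs = [x_i]
        for _ in range(1, self.seq_len):
            hs_i, cs_i = self.cell(x_i, (hs_i, cs_i))
            x_i = self.dense(hs_i)
            outputs.append(x_i)
        # note: elements come out in reverse order, x[N] down to x[1]
        return torch.stack(outputs, dim=1)  # (batch, seq_len, n_features)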

Good morning, and thank you :)

The whole time I was typing, I was thinking “the input to the LSTM cell is the embedding from the encoder,” even though I knew that the input to the LSTM cell is the last reconstructed input. I’ve been staring at that diagram for two days!

To your question about torch.cat(): I’m collecting each reconstructed element of the sequence back into a sequence…
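
Conceptually it's just this, with toy shapes (not my real sizes), keeping a sequence dimension to append along:

x = torch.empty(4, 0, 10)                        # (batch, seq so far, n_features)
for _ in range(3):
    x_i = torch.randn(4, 10)                     # one reconstructed element
    x = torch.cat([x, x_i.unsqueeze(1)], dim=1)  # append along the sequence dimension
print(x.shape)                                   # torch.Size([4, 3, 10])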

Is that bad/wrong?

With neural networks, that’s always a tricky question :).

Networks really only complain when the shapes of the tensors don’t work out as inputs for the layers. That’s essentially the only case where PyTorch throws errors. And since this seemed to be your concern here, I merely tried to get a sense of what’s going on based on the shapes.

Even if a network does not throw any errors, one can sometimes quickly spot that something is off. For example, a couple of weeks ago, someone tried to figure out why his RNN-based classifier wasn’t working, and it turned out that he used the first and not the last hidden state for further processing.

Beyond that, whether something is bad or wrong depends on the data, the task, and whatnot.

Well, the forward pass and training are happening… But it appears that the model just converges to predicting the sequence’s average value… a flat line. Bummer.

But, I’ve marked your answer as my solution and I’ll ask a separate question about that!