Seq2seq LSTM: Difference between decoders that loop over the sequence and those that don't?

Hi everyone,

My first post here - I really enjoy working with PyTorch but I’m slowly getting to the point where I’m not able to answer any questions I have by myself anymore. 🙂

I’m trying to forecast time series with a seq2seq LSTM model, and I’m struggling to understand the difference between two variations of these models that I have seen. In one, the decoder contains a loop over all time steps t of the sequence; in the other, all steps go in and come out at once. To illustrate this, here are two minimal code examples:

Variant A: With loop over the decoder part (an unrolled decoder (?)):

class Seq2Seq(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, batch_size, sequence_length):
        super(Seq2Seq, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.encoder_lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        self.decoder_lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers)
        self.linear = nn.Linear(self.hidden_size, 1)

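    def forward(self, input):
        # Encode the whole input sequence; keep only the final (h, c) state
        _, hidden = self.encoder_lstm(input)
        # First decoder input: zeros of shape (1, batch, 1), since decoder_lstm is not batch_first
        input_t = torch.zeros(self.batch_size, 1, dtype=torch.float).unsqueeze(0)
        output_tensor = torch.zeros(self.sequence_length, self.batch_size, 1)
        # Decode one step at a time, feeding each prediction back in as the next input
        # (this only works because input_size == 1 here)
        for t in range(self.sequence_length):
            output_t, hidden = self.decoder_lstm(input_t, hidden)
            output_t = self.linear(output_t[-1])
            input_t = output_t.unsqueeze(0)
            output_tensor[t] = output_t
        return output_tensor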

Variant B: Without loop over the decoder part (a rolled decoder (?)):

class Seq2SeqB(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(Seq2SeqB, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        self.linear = nn.Linear(self.hidden_size, 1)

    def forward(self, input):
        output, hidden = self.lstm(input)
        output = self.linear(output)
        return output

I understand the code (I hope), and both models work, but what I don’t really get is the difference in intuition behind the two approaches. When I train the second model (without the loop) on simple sine-wave data (like that used in this official PyTorch tutorial), the loss decreases much faster than for the first one. But I’m wondering if I am somehow cheating myself here, since most seq2seq models I’ve seen use a loop in the decoder (and to make actual forecasts into the future I would still have to add a loop either way).
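Regarding the loop needed for actual forecasts: below is a minimal sketch of autoregressive forecasting with a model shaped like Seq2SeqB. It assumes a slightly modified `forward` that also accepts and returns the LSTM hidden state (the version in the post discards it), and `forecast` is a hypothetical helper name. The observed history goes through the LSTM in a single call; only the steps beyond it need a loop, because there is no ground truth left to feed in and the model's own predictions have to be used instead.

```python
import torch
from torch import nn

# Seq2SeqB-like model; assumption: forward() also takes/returns the hidden state.
class Seq2SeqB(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, input, hidden=None):
        # hidden=None means "start from zeros", exactly like nn.LSTM's default
        output, hidden = self.lstm(input, hidden)
        return self.linear(output), hidden


@torch.no_grad()
def forecast(model, history, steps):
    """Hypothetical helper: predict `steps` future values autoregressively.

    history: (batch, seq_len, 1) observed values.
    Returns: (batch, steps) predicted values.
    """
    # 1) Warm up on the observed history in one call (no loop needed here)
    output, hidden = model(history)
    next_value = output[:, -1:, :]  # (batch, 1, 1): first value after the history
    predictions = [next_value]
    # 2) Past the history, feed each prediction back in, one step at a time
    for _ in range(steps - 1):
        next_value, hidden = model(next_value, hidden)  # (batch, 1, 1)
        predictions.append(next_value)
    return torch.cat(predictions, dim=1).squeeze(2)


model = Seq2SeqB(input_size=1, hidden_size=51, num_layers=2)
history = torch.sin(torch.linspace(0, 20, 200)).view(1, -1, 1)  # (1, 200, 1)
future = forecast(model, history, steps=50)
print(future.shape)
```

Carrying `hidden` forward means the history is never re-processed; each loop iteration costs a single LSTM step.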

Any help would be much appreciated.

Best,
Chris



P.S. Just for reference, here is the full minimal code with generated data and training:

import torch
from torch import nn, optim
import numpy as np

# Generate data

sinewave = np.sin(np.arange(0, 2000, 0.1))
slices = sinewave.reshape(-1, 200)
input_tensor = torch.tensor(slices[:, :-1], dtype=torch.float).unsqueeze(2)
target_tensor = torch.tensor(slices[:, 1:], dtype=torch.float)
print(input_tensor.shape, target_tensor.shape)


# Model - seq2seq model with loop over decoder

class Seq2SeqA(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, batch_size, sequence_length):
        super(Seq2SeqA, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.encoder_lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        self.decoder_lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers)
        self.linear = nn.Linear(self.hidden_size, 1)

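    def forward(self, input):
        # Encode the whole input sequence; keep only the final (h, c) state
        _, hidden = self.encoder_lstm(input)
        # First decoder input: zeros of shape (1, batch, 1), since decoder_lstm is not batch_first
        input_t = torch.zeros(self.batch_size, 1, dtype=torch.float).unsqueeze(0)
        output_tensor = torch.zeros(self.sequence_length, self.batch_size, 1)
        # Decode one step at a time, feeding each prediction back in as the next input
        # (this only works because input_size == 1 here)
        for t in range(self.sequence_length):
            output_t, hidden = self.decoder_lstm(input_t, hidden)
            output_t = self.linear(output_t[-1])
            input_t = output_t.unsqueeze(0)
            output_tensor[t] = output_t

        return output_tensor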

seq2seqA = Seq2SeqA(input_size=1, hidden_size=51, num_layers=1, batch_size=100, sequence_length=199)


# Training - seq2seq model with loop over decoder

num_epochs = 300
criterion = nn.MSELoss()
optimizer = optim.Adam(seq2seqA.parameters(), lr=0.001)

for epoch in range(num_epochs):
    optimizer.zero_grad()
    output = seq2seqA(input_tensor)
    output = output.squeeze().transpose(1,0)
    loss = criterion(output, target_tensor)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print('Epoch: {} -- Training loss (MSE) {}'.format(epoch, loss.item()))


# Model - seq2seq model without loop over decoder

class Seq2SeqB(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(Seq2SeqB, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        self.linear = nn.Linear(self.hidden_size, 1)

    def forward(self, input):
        output, hidden = self.lstm(input)
        output = self.linear(output)
        return output

seq2seqB = Seq2SeqB(input_size=1, hidden_size=51, num_layers=2)


# Training- seq2seq model without loop over decoder

num_epochs = 300
criterion = nn.MSELoss()
optimizer = optim.Adam(seq2seqB.parameters(), lr=0.001)

for epoch in range(num_epochs):
    optimizer.zero_grad()
    output = seq2seqB(input_tensor)
    output = output.squeeze()
    loss = criterion(output, target_tensor)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print('Epoch: {} -- Training loss (MSE) {}'.format(epoch, loss.item()))


I have the same question! I believe it might have to do with the data format. For instance, with batch_first, if we preserve the hidden state and feed the data in sequentially, we should be able to pass in K single-step batches that together correspond to one regular 1 x K sequence when we do the unrolling. But I am not 100% sure…
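To check the unrolling idea, the sketch below (arbitrary sizes, random data) runs the same nn.LSTM once over the full sequence and once in a Python loop, one time step per call, carrying (h, c) forward. The outputs agree up to floating-point tolerance, which suggests the loop itself is not what separates the two variants. What differs in the posted code is the input at each step: Seq2SeqB receives the true series value every time, while the decoder in Seq2SeqA receives its own previous prediction (starting from zero).

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=1, hidden_size=8, num_layers=2, batch_first=True)
x = torch.randn(4, 10, 1)  # (batch, seq_len, features)

# "Rolled": the whole sequence in a single call
out_full, (h_full, c_full) = lstm(x)

# "Unrolled": one time step per call, passing (h, c) on manually
hidden = None
steps = []
for t in range(x.size(1)):
    out_t, hidden = lstm(x[:, t:t + 1, :], hidden)  # t:t+1 keeps the time dimension
    steps.append(out_t)
out_loop = torch.cat(steps, dim=1)

print(torch.allclose(out_full, out_loop, atol=1e-5))
print(torch.allclose(h_full, hidden[0], atol=1e-5))
```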

Seq2SeqB is not an encoder-decoder architecture: you only have one LSTM module (essentially just the decoder). Since you use the output at every time step and not just the last hidden state, this setup is more appropriate for the task of sequence labelling (e.g., part-of-speech tagging or named entity recognition in NLP). A Seq2Seq model, in principle, encodes a complete sequence and generates a new sequence based on that encoding.
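A minimal sketch of that encode-then-generate structure, loosely following Seq2SeqA (the class name `EncoderDecoder` and the `target_length` argument are illustrative, not from the thread). The encoder summarises the whole input into its final (h, c); the decoder then runs for as many steps as requested, so the output length is independent of the input length.

```python
import torch
from torch import nn


class EncoderDecoder(nn.Module):
    """Hypothetical minimal encoder-decoder in the spirit of Seq2SeqA."""

    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.encoder = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # The decoder consumes its own scalar predictions, so its input size is 1
        self.decoder = nn.LSTM(1, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, source, target_length):
        # Encode the *complete* source sequence; only the final (h, c) is kept
        _, hidden = self.encoder(source)
        # Zero "start" input, as in Seq2SeqA: shape (batch, 1, 1)
        step_input = source.new_zeros(source.size(0), 1, 1)
        outputs = []
        for _ in range(target_length):
            out, hidden = self.decoder(step_input, hidden)
            step_output = self.linear(out)  # (batch, 1, 1)
            outputs.append(step_output)
            step_input = step_output  # feed the prediction back in
        return torch.cat(outputs, dim=1).squeeze(2)  # (batch, target_length)


model = EncoderDecoder(input_size=1, hidden_size=16)
src = torch.randn(3, 30, 1)  # 30 input steps
print(model(src, target_length=7).shape)
```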

To make the problems with using Seq2SeqB more tangible, let me show why you couldn’t (meaningfully) use it for, say, machine translation:

  • With Seq2SeqB, the output sequence always has the same length as the input sequence. In machine translation, this is rarely the case.

  • With Seq2SeqB, output[0] depends only on input[0], output[1] only on input[0] and input[1], output[2] only on input[0] to input[2], and so on (a BiLSTM is a bit more complex). A true Seq2Seq model, however, first encodes the whole sentence, since even the first output word might depend on words or phrases anywhere in the input sentence.
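The dependency structure in the second point can be checked with autograd. In this sketch (a single unidirectional nn.LSTM with random weights, sizes chosen arbitrarily), the gradient of the output at step t = 2 with respect to the inputs should be nonzero for steps 0 to 2 and exactly zero afterwards, since later inputs never reach that output.

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=1, hidden_size=8, batch_first=True)
x = torch.randn(1, 6, 1, requires_grad=True)  # one sequence of 6 steps
output, _ = lstm(x)

t = 2
# Gradient of output[t] (summed over hidden units) w.r.t. every input step
(grad,) = torch.autograd.grad(output[0, t].sum(), x)
depends_on = (grad[0, :, 0] != 0).tolist()
print(depends_on)  # which input steps influence output[t]
```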

In short, both your models are “many-to-many” models (in contrast to using an LSTM for classification, which would be a “many-to-one” model). However, Seq2SeqB is the common approach for sequence labelling, while Seq2SeqA is a true encoder-decoder model for sequence generation.
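The many-to-many versus many-to-one distinction comes down to which LSTM outputs are used. A short sketch with the same sizes as in the thread (the `head` layer is an illustrative name): many-to-many applies the linear layer to every time step, as Seq2SeqB does, while many-to-one uses only the last layer’s final hidden state. For a unidirectional LSTM, the last time step of `output` equals `h_n[-1]`.

```python
import torch
from torch import nn

lstm = nn.LSTM(input_size=1, hidden_size=51, num_layers=2, batch_first=True)
head = nn.Linear(51, 1)
x = torch.randn(100, 199, 1)  # (batch, seq_len, features)

output, (h_n, c_n) = lstm(x)

# Many-to-many (sequence labelling, as in Seq2SeqB): one prediction per step
per_step = head(output)  # (100, 199, 1)

# Many-to-one (e.g. classification): one prediction per sequence,
# taken from the last layer's final hidden state
per_sequence = head(h_n[-1])  # (100, 1)
```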
