Error when feeding the previous time step's prediction into the decoder in the encoder-decoder architecture

Hello PyTorch developers,

I’m solving Exercise 4 from Chapter 9.7 of the Dive into Deep Learning book, which covers the sequence-to-sequence encoder-decoder architecture. I’m getting an error; I know what causes it, but I don’t know how to fix it. The exercise goes as follows:

In training, replace teacher forcing with feeding the prediction at the previous time step into the decoder. How does this influence the performance?

Below is the original training code, which works; it uses teacher forcing for the decoder input:

#@save
def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    """Train a model for sequence to sequence."""
    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = d2l.Timer()
        metric = d2l.Accumulator(2)  # Sum of training loss, no. of tokens
        for batch in data_iter:
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                               device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)  # Teacher forcing
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()  # Make the loss scalar for `backward`
            d2l.grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
    print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '
          f'tokens/sec on {str(device)}')
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 300, d2l.try_gpu()

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers,
                         dropout)
decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers,
                         dropout)
net = d2l.EncoderDecoder(encoder, decoder)
train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
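
(Just to spell out the teacher-forcing line above: dec_input is built by prepending <bos> and dropping the last target token, so at every time step the decoder is fed the ground-truth previous token. A tiny illustration with made-up token indices:)

import torch

bos = torch.tensor([[1], [1]])              # '<bos>' index for a batch of 2
Y = torch.tensor([[5, 7, 9], [6, 8, 2]])    # ground-truth target tokens
dec_input = torch.cat([bos, Y[:, :-1]], 1)  # tensor([[1, 5, 7], [1, 6, 8]])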

Here are the encoder and decoder classes, along with the interfaces they implement. There’s also a custom loss function (MaskedSoftmaxCELoss) defined in the book, but it’s not relevant to the problem, so I’m omitting it.

from torch import nn

#@save
class Encoder(nn.Module):
    """The base encoder interface for the encoder-decoder architecture."""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError
#@save
class Seq2SeqEncoder(d2l.Encoder):
    """The RNN encoder for sequence to sequence learning."""
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqEncoder, self).__init__(**kwargs)
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,
                          dropout=dropout)

    def forward(self, X, *args):
        # The output `X` shape: (`batch_size`, `num_steps`, `embed_size`)
        X = self.embedding(X)
        # In RNN models, the first axis corresponds to time steps
        X = X.permute(1, 0, 2)
        # When state is not mentioned, it defaults to zeros
        output, state = self.rnn(X)
        # `output` shape: (`num_steps`, `batch_size`, `num_hiddens`)
        # `state` shape: (`num_layers`, `batch_size`, `num_hiddens`)
        return output, state
#@save
class Decoder(nn.Module):
    """The base decoder interface for the encoder-decoder architecture."""
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
class Seq2SeqDecoder(d2l.Decoder):
    """The RNN decoder for sequence to sequence learning."""
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqDecoder, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,
                          dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        return enc_outputs[1]

    def forward(self, X, state):
        # The output `X` shape: (`num_steps`, `batch_size`, `embed_size`)
        X = self.embedding(X).permute(1, 0, 2)
        # Broadcast `context` so it has the same `num_steps` as `X`
        context = state[-1].repeat(X.shape[0], 1, 1)
        X_and_context = torch.cat((X, context), 2)
        output, state = self.rnn(X_and_context, state)
        output = self.dense(output).permute(1, 0, 2)
        # `output` shape: (`batch_size`, `num_steps`, `vocab_size`)
        # `state` shape: (`num_layers`, `batch_size`, `num_hiddens`)
        return output, state
#@save
class EncoderDecoder(nn.Module):
    """The base class for the encoder-decoder architecture."""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)

Now, here’s my modified training code, along with its output. I added some print statements so I can see what’s going on:

import numpy as np

#@save
def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    """Train a model for sequence to sequence."""
    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = d2l.Timer()
        metric = d2l.Accumulator(2)  # Sum of training loss, no. of tokens
        for batch in data_iter:
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            print("Y.shape:")
            print(Y.shape)
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                               device=device).reshape(-1, 1)
            print("bos.shape:")
            print(bos.shape)
            # if it's the first time this is running, we don't have Y_hat, and we need it:
            if not "Y_hat" in locals(): # https://stackoverflow.com/questions/843277/how-do-i-check-if-a-variable-exists
                print("Here")
                Y_hat = torch.rand(size=(Y.shape[0], Y.shape[1], len(tgt_vocab)), device=device)
            print("Y_hat.shape:")
            print(Y_hat.shape)
            Y_hat = np.argmax(Y_hat.cpu().detach().numpy(), axis=2)
            Y_hat = torch.tensor(Y_hat).to(device)
            dec_input = torch.cat([bos, Y_hat[:bos.shape[0], :-1]], 1)
            Y_hat, _ = net(X, dec_input, X_valid_len)
            print("Y_hat.shape after net:")
            print(Y_hat.shape)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()  # Make the loss scalar for `backward`
            d2l.grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
    print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '
          f'tokens/sec on {str(device)}')
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 300, d2l.try_gpu()

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers,
                         dropout)
print("len(src_vocab):")
print(len(src_vocab))
decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers,
                         dropout)
print("len(tgt_vocab):")
print(len(tgt_vocab))
net = d2l.EncoderDecoder(encoder, decoder)
train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
Here's the output:

len(src_vocab):
184
len(tgt_vocab):
201
Y.shape:
torch.Size([64, 10])
bos.shape:
torch.Size([64, 1])
Here
Y_hat.shape:
torch.Size([64, 10, 201])
Y_hat.shape after net:
torch.Size([64, 10, 201])
Y.shape:
torch.Size([64, 10])
bos.shape:
torch.Size([64, 1])
Y_hat.shape:
torch.Size([64, 10, 201])
Y_hat.shape after net:
torch.Size([64, 10, 201])
Y.shape:
torch.Size([64, 10])
bos.shape:
torch.Size([64, 1])
Y_hat.shape:
torch.Size([64, 10, 201])
Y_hat.shape after net:
torch.Size([64, 10, 201])
Y.shape:
torch.Size([64, 10])
bos.shape:
torch.Size([64, 1])
Y_hat.shape:
torch.Size([64, 10, 201])
Y_hat.shape after net:
torch.Size([64, 10, 201])
Y.shape:
torch.Size([64, 10])
bos.shape:
torch.Size([64, 1])
Y_hat.shape:
torch.Size([64, 10, 201])
Y_hat.shape after net:
torch.Size([64, 10, 201])
Y.shape:
torch.Size([64, 10])
bos.shape:
torch.Size([64, 1])
Y_hat.shape:
torch.Size([64, 10, 201])
Y_hat.shape after net:
torch.Size([64, 10, 201])
Y.shape:
torch.Size([64, 10])
bos.shape:
torch.Size([64, 1])
Y_hat.shape:
torch.Size([64, 10, 201])
Y_hat.shape after net:
torch.Size([64, 10, 201])
Y.shape:
torch.Size([64, 10])
bos.shape:
torch.Size([64, 1])
Y_hat.shape:
torch.Size([64, 10, 201])
Y_hat.shape after net:
torch.Size([64, 10, 201])
Y.shape:
torch.Size([64, 10])
bos.shape:
torch.Size([64, 1])
Y_hat.shape:
torch.Size([64, 10, 201])
Y_hat.shape after net:
torch.Size([64, 10, 201])
Y.shape:
torch.Size([25, 10])
bos.shape:
torch.Size([25, 1])
Y_hat.shape:
torch.Size([64, 10, 201])
Y_hat.shape after net:
torch.Size([25, 10, 201])
Y.shape:
torch.Size([64, 10])
bos.shape:
torch.Size([64, 1])
Y_hat.shape:
torch.Size([25, 10, 201])

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_71062/1885358373.py in <module>
     13 print(len(tgt_vocab))
     14 net = d2l.EncoderDecoder(encoder, decoder)
---> 15 train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

/tmp/ipykernel_71062/3336883265.py in train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device)
     39             Y_hat = np.argmax(Y_hat.cpu().detach().numpy(), axis=2)
     40             Y_hat = torch.tensor(Y_hat).to(device)
---> 41             dec_input = torch.cat([bos, Y_hat[:bos.shape[0], :-1]], 1) # added :bos.shape[0], otherwise I get an error
     42             Y_hat, _ = net(X, dec_input, X_valid_len)
     43             print("Y_hat.shape after net:")

RuntimeError: Sizes of tensors must match except in dimension 0. Got 25 and 64 (The offending index is 0)

I know why I get this error. When I process the last batch of an epoch (whose batch size is only 25) and then move on to the first batch of the next epoch, Y_hat is still the one left over from that last batch, so its first dimension is 25, while it should be 64 (the regular batch size).

How do I fix this? I need Y_hat before computing dec_input so I can concatenate bos with it, but my current implementation clearly has issues with that approach.
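
To make the shape mismatch concrete, here is a minimal reproduction using the shapes from the printout above (the leftover Y_hat only has 25 rows, so slicing it to bos.shape[0] cannot produce 64 rows):

import torch

bos = torch.zeros(64, 1, dtype=torch.long)          # first batch of the new epoch
prev_Y_hat = torch.zeros(25, 10, dtype=torch.long)  # argmax of the last, smaller batch
dec_input = torch.cat([bos, prev_Y_hat[:bos.shape[0], :-1]], 1)
# raises the same RuntimeError as above: sizes 25 and 64 don't match in dimension 0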

If I may ask: could you describe your plan for doing this in words, i.e. how do you deal with the sequence steps? From the code, I cannot really tell; I would have expected a for loop over the sequence index.

My original plan was this: I want the encoder-decoder network to predict the tokens of each batch, and I store its outputs in the variable Y_hat. The decoder input (dec_input) is the concatenation of bos (the beginning-of-sequence tokens) and Y_hat, but Y_hat doesn’t exist yet when the batch starts, since I’d have to run the inputs through the network first. So I initialize Y_hat randomly for the very first batch, and for every subsequent batch I reuse the Y_hat left over from the previous batch.

Writing this out, it doesn’t make much sense: Y_hat should be the network’s prediction on the current batch, not the previous one. However, I have to concatenate the beginning-of-sequence tokens (bos) with Y_hat before passing anything to the network. How can I accomplish this?

My impression is that you might be looking for a clever solution where none exists.
What you are asked to do is quite mundane: you feed the input, and the decoder simply runs in a loop over the sequence, one time step at a time (starting from some bos token).

If it’s not a problem for you, could you write up some code to show me what you mean?

Here’s how I understood it:

First of all, the Encoder, Decoder and EncoderDecoder interfaces are defined as follows:

from torch import nn

#@save
class Encoder(nn.Module):
    """The base encoder interface for the encoder-decoder architecture."""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError
#@save
class Decoder(nn.Module):
    """The base decoder interface for the encoder-decoder architecture."""
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
#@save
class EncoderDecoder(nn.Module):
    """The base class for the encoder-decoder architecture."""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)

Now, here is the training code (unaltered from the book):

#@save
def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    """Train a model for sequence to sequence."""
    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = d2l.Timer()
        metric = d2l.Accumulator(2)  # Sum of training loss, no. of tokens
        for batch in data_iter:
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                               device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)  # Teacher forcing
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()  # Make the loss scalar for `backward`
            d2l.grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
    print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '
          f'tokens/sec on {str(device)}')

What was your idea? Where does the decoder run in a loop? As you can see from the code above, the decoder is fed the dec_input variable, whose dimensions are torch.Size([64, 10]) (I checked this in the code). Did you mean feeding the decoder something like the following:

<bos> first_token <pad> <pad> ... <pad>

then, at the second iteration:

<bos> first_token second_token <pad> ... <pad>

and so on?
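
Or did you rather mean decoding one time step at a time inside the batch loop and feeding each argmax prediction back in? Something like this rough, untested sketch, which would replace the teacher-forcing lines inside the batch loop (I’m calling net.encoder and net.decoder directly, which may not be what you intended):

# rough sketch, not tested; replaces the teacher-forcing lines inside the batch loop
enc_outputs = net.encoder(X, X_valid_len)
dec_state = net.decoder.init_state(enc_outputs, X_valid_len)
# start every sequence with the <bos> token; shape (batch_size, 1)
dec_input = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                         device=device).reshape(-1, 1)
outputs = []
for t in range(Y.shape[1]):
    # one decoding step: Y_hat_t has shape (batch_size, 1, vocab_size)
    Y_hat_t, dec_state = net.decoder(dec_input, dec_state)
    outputs.append(Y_hat_t)
    # the prediction of this step becomes the input of the next step
    dec_input = Y_hat_t.argmax(dim=2)
Y_hat = torch.cat(outputs, dim=1)  # (batch_size, num_steps, vocab_size)
l = loss(Y_hat, Y, Y_valid_len)

Is that the loop over the sequence index you had in mind?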

I would appreciate it if you could clarify this a bit.

Thank you in advance!