Seq2Seq - Translation - Feedback appreciated

Hi everyone,

I built a simple sequence-to-sequence model, without attention, to translate sentences from German to English.

Here’s the model:

import random

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(input_size, embedding_size)
        self.dropout = nn.Dropout(p=0.5)
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers=2, dropout=0.5)

    def forward(self, input):
        # torch.Size([21, 32])
        x = self.dropout(self.embed(input))
        # torch.Size([21, 32, 300])
        outputs, (hidden_state, cell_state) = self.lstm(x)
        # torch.Size([21, 32, 1024]) torch.Size([2, 32, 1024]) torch.Size([2, 32, 1024])

        return hidden_state, cell_state


class Decoder(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, output_size):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(input_size, embedding_size)
        self.dropout = nn.Dropout(p=0.5)
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers=2, dropout=0.5)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden_state_encoder, cell_state_encoder):
        # torch.Size([1, 32])
        x = self.dropout(self.embed(input))
        # torch.Size([1, 32, 300])
        outputs, (hidden_state, cell_state) = self.lstm(x, (hidden_state_encoder, cell_state_encoder))
        # torch.Size([1, 32, 1024]) torch.Size([2, 32, 1024]) torch.Size([2, 32, 1024])
        # nn.Linear operates on the last dimension, so no permute is needed
        y_hat = self.fc(outputs)
        # torch.Size([1, 32, 4556])
        return y_hat, hidden_state, cell_state


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device, output_size):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.output_size = output_size

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        trg_seq_len = trg.size(0)
        batch_size = src.size(1)

        outputs = torch.zeros(trg_seq_len, batch_size, self.output_size).to(self.device)
        # torch.Size([*, 32, 4556])

        hidden, cell = self.encoder(src)

        x = trg[0, :]  # first input to the decoder: the <sos> tokens of the batch
        x = x.unsqueeze(0)
        # torch.Size([1, 32])
        for i in range(1, trg_seq_len):
            y, hidden, cell = self.decoder(x, hidden, cell)
            # torch.Size([1, 32, 4556])
            outputs[i] = y.squeeze(0)

            teacher_force = random.random() < teacher_forcing_ratio

            x = trg[i].unsqueeze(0) if teacher_force else y.argmax(2)

        return outputs
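For reference, this is roughly how I wire everything together (the sizes match the shape comments above; the source vocabulary size is a placeholder):

# Wiring sketch; the source vocab size is a placeholder, the target
# vocab size of 4556 matches the shape comments above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
src_vocab_size = 8000  # placeholder; depends on the German vocabulary
trg_vocab_size = 4556
encoder = Encoder(src_vocab_size, embedding_size=300, hidden_size=1024)
decoder = Decoder(trg_vocab_size, embedding_size=300, hidden_size=1024,
                  output_size=trg_vocab_size)
seq2seq_model = Seq2Seq(encoder, decoder, device, trg_vocab_size).to(device)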

The network seems to be learning: the loss is decreasing nicely on the training set. However, I’m having trouble decreasing the loss beyond a certain point on the validation set. As you can see, I’ve already added some regularization, and while that has helped a tiny bit, it hasn’t really improved the results by much.
Especially when looking at the actual translations it produces, the results aren’t all that great (though some are 100% correct).

Now, without writing a more complex model (I’m trying to work my way there; I’m just starting out with deep learning), is there anything I can do to improve the results?
Is there anything I’m doing wrong maybe?

I’m using the Multi30k dataset and the following:

criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
optimizer = torch.optim.Adam(seq2seq_model.parameters(), lr=0.003)

and also

torch.nn.utils.clip_grad_norm_(seq2seq_model.parameters(), 1)

during training.
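For context, here is a minimal sketch of one training step, showing how I reshape the outputs for CrossEntropyLoss and where the clipping happens (train_iter is a placeholder for my data iterator):

for src, trg in train_iter:  # src/trg: (seq_len, batch_size) index tensors
    optimizer.zero_grad()
    output = seq2seq_model(src, trg)  # (trg_len, batch_size, vocab_size)
    # Position 0 is never written (the decoding loop starts at 1), so skip it,
    # then flatten to the (N, C) logits and (N,) targets CrossEntropyLoss expects.
    loss = criterion(output[1:].reshape(-1, output.size(-1)),
                     trg[1:].reshape(-1))
    loss.backward()
    # Clip after backward() and before step() so a single bad batch
    # can't produce a huge parameter update.
    torch.nn.utils.clip_grad_norm_(seq2seq_model.parameters(), 1)
    optimizer.step()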

Here is the progress I’m seeing while training:

1, progress:5%, train-loss: 4.4046, vald-loss: 3.9111, train-perp: 81.8246, vald-perp: 49.9540
2, progress:10%, train-loss: 3.8298, vald-loss: 3.5941, train-perp: 46.0512, vald-perp: 36.3840
3, progress:15%, train-loss: 3.6142, vald-loss: 3.4860, train-perp: 37.1226, vald-perp: 32.6539
4, progress:20%, train-loss: 3.4447, vald-loss: 3.3735, train-perp: 31.3325, vald-perp: 29.1815
5, progress:25%, train-loss: 3.3201, vald-loss: 3.2923, train-perp: 27.6621, vald-perp: 26.9046
6, progress:30%, train-loss: 3.2181, vald-loss: 3.2581, train-perp: 24.9803, vald-perp: 25.9999
7, progress:35%, train-loss: 3.1210, vald-loss: 3.1831, train-perp: 22.6684, vald-perp: 24.1222
8, progress:40%, train-loss: 3.0291, vald-loss: 3.2246, train-perp: 20.6779, vald-perp: 25.1444
9, progress:45%, train-loss: 2.9529, vald-loss: 3.1308, train-perp: 19.1617, vald-perp: 22.8918
10, progress:50%, train-loss: 2.8898, vald-loss: 3.1745, train-perp: 17.9897, vald-perp: 23.9156
11, progress:55%, train-loss: 2.8239, vald-loss: 3.1459, train-perp: 16.8417, vald-perp: 23.2398
12, progress:60%, train-loss: 2.7759, vald-loss: 3.1733, train-perp: 16.0525, vald-perp: 23.8866
13, progress:65%, train-loss: 2.7048, vald-loss: 3.2212, train-perp: 14.9520, vald-perp: 25.0574
...

Any pointers or feedback would be much appreciated!

Since you mentioned that the training loss is decreasing, your model is most likely implemented correctly. However, you would still like to get the training loss close to zero. Here are some things you can try to decrease the validation loss:

  1. Ensure your dataset is of high quality and cleanly split between training and validation sets.
  2. Try to slightly overfit the model (before adding regularization techniques like dropout). You can feed in more data, train for more epochs, and increase the complexity of the model (e.g., use a bidirectional LSTM for the encoder; see the sketch after this list).
  3. Do hyperparameter tuning at the end.
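For illustration, here is a rough sketch of a bidirectional variant of your encoder (just one way to do it; since the decoder stays unidirectional, the forward and backward states are concatenated and projected back to hidden_size):

class BiEncoder(nn.Module):
    # Sketch of a bidirectional encoder; the Decoder above stays unchanged.
    def __init__(self, input_size, embedding_size, hidden_size, num_layers=2):
        super().__init__()
        self.num_layers = num_layers
        self.embed = nn.Embedding(input_size, embedding_size)
        self.dropout = nn.Dropout(p=0.5)
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers=num_layers,
                            dropout=0.5, bidirectional=True)
        # Project concatenated forward/backward states back to hidden_size
        # so the unidirectional decoder can consume them.
        self.fc_hidden = nn.Linear(hidden_size * 2, hidden_size)
        self.fc_cell = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, input):
        x = self.dropout(self.embed(input))
        outputs, (hidden, cell) = self.lstm(x)
        # (num_layers * 2, batch, hidden) -> (num_layers, 2, batch, hidden)
        hidden = hidden.view(self.num_layers, 2, *hidden.shape[1:])
        cell = cell.view(self.num_layers, 2, *cell.shape[1:])
        # Concatenate forward (index 0) and backward (index 1) per layer, then project.
        hidden = torch.tanh(self.fc_hidden(torch.cat((hidden[:, 0], hidden[:, 1]), dim=2)))
        cell = self.fc_cell(torch.cat((cell[:, 0], cell[:, 1]), dim=2))
        return hidden, cell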

Also, use the BLEU score (or a similar metric) to assess the quality of the machine translation.
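One way to compute a corpus-level BLEU score is with NLTK (the sentences below are made up; each hypothesis gets a list of reference token lists):

from nltk.translate.bleu_score import corpus_bleu

# One list of (tokenized) reference translations per hypothesis.
references = [[["a", "man", "is", "riding", "a", "horse"]],
              [["two", "dogs", "are", "playing", "in", "the", "snow"]]]
hypotheses = [["a", "man", "is", "riding", "a", "horse"],
              ["two", "dogs", "play", "in", "the", "snow"]]
print(corpus_bleu(references, hypotheses))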

Thank you, your answer is much appreciated.
I have tried a bidirectional encoder and also increased the complexity in other ways (more layers, played with hidden and embedding dimensions, …), but nothing seemed to make a real difference.
One of the things that has surprised me the most about this learning experience is how little regularization actually helped. I was under the impression it would be way more effective. It definitely makes the train loss decrease more slowly, but sadly it doesn’t make the validation loss decrease further to nearly the same extent.
I will definitely look at the BLEU score and try to find a bigger dataset, thank you!

I can’t see anything obvious that looks suspicious, so here are just some comments as food for thought. By the way, I appreciate that you keep track of your dimensions in the comments :). This always helps a lot, I think.

  • Can you overfit the model on a (very) small training dataset, i.e., can you achieve a training loss of 0 if you train on only, say, 50 sentence pairs? (See the sketch after this list.)

  • I don’t have any solid argument, but a dropout after the embedding layer always looks a bit weird to me.

  • From your code it seems you use batches of size > 1 and padding. I’m actually not quite sure how this works out for the decoder part. The official Seq2Seq tutorial for PyTorch uses batches of size 1 to avoid this issue, sacrificing performance of course. Anyway, for testing, I would try to train using batches of size 1 to avoid any kind of padding for the input and target sequences.

  • Not sure what your sequences look like, but I assume you properly add <sos> and <eos> tokens to your sequences before adding the padding tokens.

  • Lastly, machine translation is not a trivial task just because big companies might be able to do it almost flawlessly by now. So it’s not obvious what kind of loss can be expected in the end given a basic model. Hence my suggestion to see if you can overfit the model on a very small dataset.
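Something like this minimal sanity check is what I have in mind (just a sketch; train_iter, seq2seq_model, criterion, and optimizer refer to your setup, and you may want to lower the dropout for this test):

# Overfitting sanity check: reuse one small, fixed batch and watch the loss.
small_src, small_trg = next(iter(train_iter))  # e.g., ~50 sentence pairs
for step in range(1000):
    optimizer.zero_grad()
    output = seq2seq_model(small_src, small_trg, teacher_forcing_ratio=1.0)
    loss = criterion(output[1:].reshape(-1, output.size(-1)),
                     small_trg[1:].reshape(-1))
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(step, loss.item())
# If the loss does not approach 0 here, something in the model or data
# pipeline is broken; if it does, the model itself is probably fine.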

That’s actually a claim I’d be willing to challenge :). Just because a model learns something doesn’t mean all is good. For example, I’ve seen dropout layers placed after the output layer, and the loss arguably still goes down, at least to some degree.

I really appreciate all your suggestions!

Yes, after about 200 epochs the loss is 0.0002. If I kept going it would probably get to 0, but at that point training is pretty slow.

I don’t really have a solid argument for that either. I saw someone do it and gave it a shot, but I don’t really get why it would be a good idea either.

I assumed the padding didn’t matter since the loss function ignores it anyway. You’re right though, I could try it without padding.
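To illustrate what I mean: ignore_index only masks the loss at padded target positions (the toy values below are made up):

import torch
import torch.nn as nn

PAD_IDX = 1
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
logits = torch.randn(4, 10)                       # 4 positions, vocab of 10
targets = torch.tensor([3, 7, PAD_IDX, PAD_IDX])  # last two are padding
# The padded positions contribute nothing; the mean is taken over the
# 2 real tokens only. Padding still reaches the LSTM inputs, though.
print(criterion(logits, targets))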

I’m not sure I fully get that last point. Could you elaborate on it?