Hidden units saturate in a seq2seq model in PyTorch

I’m trying to write a very simple machine translation toy example in PyTorch. To simplify the question, I turned the machine translation task into this one:

Given a random sequence ([4, 8, 9, ...]), predict the sequence whose elements are the input elements plus 1 ([5, 9, 10, ...]). IDs 0, 1, and 2 are reserved for pad, bos, and eos, respectively.
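For concreteness, here is a minimal sketch of how one training pair could be built under this scheme. It mirrors the `tgt_in`/`tgt_out` construction in `load_data` in the code below; the helper name `make_pair` is hypothetical and only for illustration.

```python
PAD, BOS, EOS = 0, 1, 2

def make_pair(src):
    """Build (source, decoder input, decoder target) for the +1 toy task."""
    shifted = [x + 1 for x in src]  # target tokens are the source tokens plus 1
    tgt_in = [BOS] + shifted        # decoder input starts with bos (teacher forcing)
    tgt_out = shifted + [EOS]       # decoder target: same tokens, one step ahead, ending with eos
    return src, tgt_in, tgt_out

src, tgt_in, tgt_out = make_pair([4, 8, 9])
print(tgt_in)   # [1, 5, 9, 10]
print(tgt_out)  # [5, 9, 10, 2]
```

At every decoder step the input token is the previous target token, so the decoder still has to recover the next token from the encoder state.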

I observe the same problem in this toy task as in my machine translation task. To debug, I used a very small dataset (n_data = 50) and found that the model cannot even overfit it. Looking into the model, I found that the hidden state of the encoder/decoder quickly saturates: nearly all units in the hidden state end up very close to 1 or -1 because of the tanh.

-0.8987  0.9634  0.9993  ...  -0.8930 -0.4822 -0.9960
-0.9673  1.0000 -0.8007  ...   0.9929 -0.9992  0.9990
-0.9457  0.9290 -0.9260  ...  -0.9932  0.9851  0.9980
          ...             ⋱             ...
-0.9995  0.9997 -0.9350  ...  -0.9820 -0.9942 -0.9913
-0.9951  0.9488 -0.8894  ...  -0.9842 -0.9895 -0.9116
-0.9991  0.9769 -0.5871  ...   0.7557  0.9049  0.9881
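Rather than eyeballing a printout like the one above, one way to quantify saturation is to measure the fraction of hidden units whose magnitude exceeds a threshold. This is a minimal sketch: the 0.99 threshold is an arbitrary choice, and `saturation_ratio` is a hypothetical helper, not a PyTorch API.

```python
import torch

def saturation_ratio(h: torch.Tensor, threshold: float = 0.99) -> float:
    """Fraction of entries in a tanh-bounded hidden state with |h| > threshold."""
    return (h.abs() > threshold).float().mean().item()

# A few values copied from the printout above.
h = torch.tensor([[-0.8987, 0.9634, 0.9993],
                  [-0.9673, 1.0000, -0.8007]])
print(saturation_ratio(h))  # 2 of the 6 entries exceed 0.99, so about 0.333
```

Logging this value every few steps for the encoder's final hidden state would show how quickly the units drift toward ±1.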

Also, no matter how I adjust the learning rate, or whether I switch the units to RNN/LSTM/GRU, the loss seems to hit a floor (around 0.23 in the log below), even with only 50 samples. With more data, the model does not seem to converge at all.

step: 0, loss: 2.313938
step: 10, loss: 1.435780
step: 20, loss: 0.779704
step: 30, loss: 0.395590
step: 40, loss: 0.281261
...
step: 480, loss: 0.231419
step: 490, loss: 0.231410

When I use TensorFlow, I can easily overfit such a dataset with a seq2seq model and reach a very small loss.

Here is what I’ve tried:

  1. Manually initializing the embedding to very small numbers;
  2. Clipping the gradient to a fixed norm, such as 1e-2, 2, 3, 5, or 10;
  3. Excluding the padding index (by passing ignore_index to NLLLoss) when computing the loss.

None of these helped.
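For reference, the three attempts map onto standard PyTorch calls roughly as follows. This is a sketch, not the poster's exact code: the sizes and clip value are placeholders, a stand-in `nn.Linear` replaces the seq2seq model, and `clip_grad_norm_` is the current in-place name of the `clip_grad_norm` function used in the code below.

```python
import torch
from torch import nn

PAD = 0
vocab_size, emb_size = 101, 50

# Attempt 1: small uniform initialization of the embedding table.
embed = nn.Embedding(vocab_size, emb_size)
nn.init.uniform_(embed.weight, -1e-3, 1e-3)

# Attempt 3: NLLLoss that skips padded target positions.
criterion = nn.NLLLoss(ignore_index=PAD)

# Stand-in projection so there is something to backpropagate through.
proj = nn.Linear(emb_size, vocab_size)
tokens = torch.tensor([[5, 9, PAD]])   # (batch=1, time=3)
targets = torch.tensor([[9, 10, PAD]])
log_probs = torch.log_softmax(proj(embed(tokens)), dim=-1)  # (1, 3, vocab)
loss = criterion(log_probs.view(-1, vocab_size), targets.view(-1))
loss.backward()

# Attempt 2: clip the global gradient norm (returns the norm before clipping).
params = list(embed.parameters()) + list(proj.parameters())
total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=3.0)
print(loss.item(), float(total_norm))
```

With `ignore_index`, the loss is averaged only over the two non-pad target positions.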

How can I get rid of this? Any help will be appreciated.

Here’s the code (it’s also on a gist for easier reading):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable

np.random.seed(0)
torch.manual_seed(0)

_RECURRENT_FN_MAPPING = {
    'rnn': torch.nn.RNN,
    'gru': torch.nn.GRU,
    'lstm': torch.nn.LSTM,
}


def get_recurrent_cell(n_inputs,
                       num_units,
                       num_layers,
                       type_,
                       dropout=0.0,
                       bidirectional=False):
    cls = _RECURRENT_FN_MAPPING.get(type_)

    return cls(
        n_inputs,
        num_units,
        num_layers,
        dropout=dropout,
        bidirectional=bidirectional)


class Recurrent(nn.Module):

    def __init__(self,
                 num_units,
                 num_layers=1,
                 unit_type='gru',
                 bidirectional=False,
                 dropout=0.0,
                 embedding=None,
                 attn_type='general'):
        super(Recurrent, self).__init__()

        num_inputs = embedding.weight.size(1)
        self._num_inputs = num_inputs
        self._num_units = num_units
        self._num_layers = num_layers
        self._unit_type = unit_type
        self._bidirectional = bidirectional
        self._dropout = dropout
        self._embedding = embedding
        self._attn_type = attn_type
        self._cell_fn = get_recurrent_cell(num_inputs, num_units, num_layers,
                                           unit_type, dropout, bidirectional)

    def init_hidden(self, batch_size):
        direction = 1 if not self._bidirectional else 2
        h = Variable(
            torch.zeros(direction * self._num_layers, batch_size,
                        self._num_units))
        if self._unit_type == 'lstm':
            return (h, h.clone())
        else:
            return h

    def forward(self, x, h, len_x):
        # Sort by sequence lengths
        sorted_indices = np.argsort(-len_x).tolist()
        unsorted_indices = np.argsort(sorted_indices).tolist()
        x = x[:, sorted_indices]
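        if self._unit_type == 'lstm':
            # The LSTM state is an (h, c) tuple; reorder both tensors
            h = tuple(t[:, sorted_indices, :] for t in h)
        else:
            h = h[:, sorted_indices, :]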
        len_x = len_x[sorted_indices].tolist()

        embedded = self._embedding(x)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, len_x)

        if self._unit_type == 'lstm':
            o, (h, c) = self._cell_fn(packed, h)
            o, _ = torch.nn.utils.rnn.pad_packed_sequence(o)
            return (o[:, unsorted_indices, :], (h[:, unsorted_indices, :],
                                                c[:, unsorted_indices, :]))
        else:
            o, hh = self._cell_fn(packed, h)
            o, _ = torch.nn.utils.rnn.pad_packed_sequence(o)
            return (o[:, unsorted_indices, :], hh[:, unsorted_indices, :])


class Encoder(Recurrent):
    pass


class Decoder(Recurrent):
    pass


class Seq2Seq(nn.Module):

    def __init__(self, encoder, decoder, num_outputs):
        super(Seq2Seq, self).__init__()
        self._encoder = encoder
        self._decoder = decoder
        self._out = nn.Linear(decoder._num_units, num_outputs)

    def forward(self, x, y, h, len_x, len_y):
        # Encode
        _, h = self._encoder(x, h, len_x)
        # Decode
        o, h = self._decoder(y, h, len_y)
        # Project
        o = self._out(o)

        return F.log_softmax(o)


def load_data(size,
              min_len=5,
              max_len=15,
              min_word=3,
              max_word=100,
              epoch=10,
              batch_size=64,
              pad=0,
              bos=1,
              eos=2):
    src = [
        np.random.randint(min_word, max_word - 1,
                          np.random.randint(min_len, max_len)).tolist()
        for _ in range(size)
    ]
    tgt_in = [[bos] + [xi + 1 for xi in x] for x in src]
    tgt_out = [[xi + 1 for xi in x] + [eos] for x in src]

    def _pad(batch):
        max_len = max(len(x) for x in batch)
        return np.asarray(
            [
                np.pad(
                    x, (0, max_len - len(x)),
                    mode='constant',
                    constant_values=pad) for x in batch
            ],
            dtype=np.int64)

    def _len(batch):
        return np.asarray([len(x) for x in batch], dtype=np.int64)

    for e in range(epoch):
        batch_start = 0

        while batch_start < size:
            batch_end = batch_start + batch_size

            s, ti, to = (src[batch_start:batch_end],
                         tgt_in[batch_start:batch_end],
                         tgt_out[batch_start:batch_end])
            lens, lent = _len(s), _len(ti)

            s, ti, to = _pad(s).T, _pad(ti).T, _pad(to).T

            yield (Variable(torch.LongTensor(s)),
                   Variable(torch.LongTensor(ti)),
                   Variable(torch.LongTensor(to)), lens, lent)

            batch_start += batch_size


def print_sample(x, y, yy):
    x = x.data.numpy().T
    y = y.data.numpy().T
    yy = yy.data.numpy().T

    for u, v, w in zip(x, y, yy):
        print('--------')
        print('S: ', u)
        print('T: ', v)
        print('P: ', w)


n_data = 50
min_len = 5
max_len = 10
vocab_size = 101
n_samples = 5

epoch = 100000
batch_size = 32
lr = 1e-2
clip = 3

emb_size = 50
hidden_size = 50
num_layers = 1
max_length = 15

src_embed = torch.nn.Embedding(vocab_size, emb_size)
tgt_embed = torch.nn.Embedding(vocab_size, emb_size)

eps = 1e-3
src_embed.weight.data.uniform_(-eps, eps)
tgt_embed.weight.data.uniform_(-eps, eps)

enc = Encoder(hidden_size, num_layers, embedding=src_embed)
dec = Decoder(hidden_size, num_layers, embedding=tgt_embed)
net = Seq2Seq(enc, dec, vocab_size)

optimizer = torch.optim.Adam(net.parameters(), lr=lr)
criterion = torch.nn.NLLLoss()

loader = load_data(
    n_data,
    min_len=min_len,
    max_len=max_len,
    max_word=vocab_size,
    epoch=epoch,
    batch_size=batch_size)

for i, (x, yin, yout, lenx, leny) in enumerate(loader):
    net.train()
    optimizer.zero_grad()

    logits = net(x, yin, enc.init_hidden(x.size()[1]), lenx, leny)
    loss = criterion(logits.view(-1, vocab_size), yout.contiguous().view(-1))

    loss.backward()

    torch.nn.utils.clip_grad_norm(net.parameters(), clip)
    optimizer.step()

    if i % 10 == 0:
        print('step: {}, loss: {:.6f}'.format(i, loss.item()))

    if i % 200 == 0 and i > 0:
        net.eval()
        x, yin, yout, lenx, leny = (x[:, :n_samples], yin[:, :n_samples],
                                    yout[:, :n_samples], lenx[:n_samples],
                                    leny[:n_samples])
        outputs = net(x, yin, enc.init_hidden(x.size()[1]), lenx, leny)
        _, preds = torch.max(outputs, 2)
        print_sample(x, yout, preds)

Using small numbers (around 1e-4) to initialize the embedding does not seem to help.

Perhaps try this

Thanks for your suggestion. I tried, but it did not help.

Clipping the gradient does not seem to help either.

This is driving me crazy…

I took a glance at the gist and didn’t see any obvious issues. I’d be happy to take a deeper look at this on Monday.

In your TF code, are you using exactly the same architecture/data/optimizer? (activation function, lr, Adam optimizer, sequence length, padding, …)

Thanks! The data, sequences, and preprocessing (padding) are exactly the same as in my original project, even though this is a simplified version. The architecture is also the same: an RNN/LSTM/GRU encoder/decoder with only 1 layer (the simplest seq2seq net, as described in the code). During training, only the learning rate is tuned, over [1e-2, 1e-3, 1e-4].

Previously in TensorFlow, for comparison, I tried initializing the embedding to large values, e.g. uniform(-1, 1), without gradient clipping, and the hidden state never went to -1/1 in thousands of steps. The PyTorch version, by contrast, goes to -1/1 within dozens of steps.

I personally think that both should easily overfit a very small dataset. (n_data=20/50 in the code.)

I couldn’t really find anything wrong. If it weren’t for your TF results, I wouldn’t be surprised if this architecture couldn’t overfit this data. Basically, you are asking for about 10 · log₂ 100 ≈ 66 bits of entropy per sequence to be encoded in a 50-dimensional hypercube (the hidden state) with only a 1-layer RNN. (I don’t think yin helps much, since it is always one step behind.) I’d expect it not to overfit.
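The capacity estimate above can be checked with a quick back-of-the-envelope calculation. This sketch uses the reply's round numbers (up to 10 tokens, about 100 token values) and assumes uniformly random tokens; the posted code actually draws 5 to 9 tokens from 97 values, so the true figure is somewhat lower. It only quantifies the estimate and does not by itself show that the model cannot fit the data.

```python
import math

max_tokens = 10   # upper bound on sequence length used in the estimate
vocab = 100       # approximate number of distinct token values
n_data = 50       # number of sequences to memorize
hidden_size = 50  # dimension of the single-layer RNN state

bits_per_sequence = max_tokens * math.log2(vocab)
bits_total = n_data * bits_per_sequence

print(f"{bits_per_sequence:.1f} bits per sequence")                # 66.4 bits per sequence
print(f"{bits_total:.0f} bits for the whole dataset")              # 3322 bits for the whole dataset
print(f"{bits_per_sequence / hidden_size:.2f} bits per hidden unit per sequence")  # 1.33 bits ...
```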

That said, if tf can do it, pytorch should also be able to do it.

Sorry for not being helpful.

Thanks again for taking a look into this.

The original problem has a vocabulary of 50000 and a hidden size of 1000; shouldn’t that be enough model capacity to overfit the data? I apologize for focusing so much on the -1/1 problem of the hidden units in this toy example; I may not have considered the model capacity. But did you also observe that the hidden units saturate? I just wrote a quick TF version of this toy example. With the same settings as the original post, the model capacity seems insufficient there too, but the hidden states never go to -1/1.

Ok, it seems that other people are interested in this.

The problem is related to this. In PyTorch 0.2.0.post3 (August), the documentation says nothing about the dim parameter of log_softmax or how it behaves when dim is None. If anyone spends lots of time on something like this with no luck, I suggest updating to master or using TensorFlow.
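For anyone hitting the same issue: in the posted code, `F.log_softmax(o)` is called on the decoder output, whose shape is `(time, batch, vocab)`, without a `dim` argument. When `dim` is omitted, current PyTorch picks one implicitly (and emits a deprecation warning); for a 3-D input the implicit choice is `dim=0`, so the normalization runs over time steps rather than over the vocabulary. The sketch below (sizes are arbitrary placeholders) shows the difference by checking which axis the probabilities sum to 1 over.

```python
import warnings

import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, B, V = 7, 4, 101            # time steps, batch size, vocabulary size
logits = torch.randn(T, B, V)  # same layout as the decoder output in Seq2Seq.forward

with warnings.catch_warnings():
    warnings.simplefilter("ignore")   # the implicit-dim call emits a UserWarning
    implicit = F.log_softmax(logits)  # dim chosen implicitly: 0 for a 3-D tensor

explicit = F.log_softmax(logits, dim=-1)  # normalize over the vocabulary

# Probabilities should sum to 1 over the vocabulary at every (t, b) position.
print(explicit.exp().sum(dim=-1).allclose(torch.ones(T, B)))  # True
print(implicit.exp().sum(dim=-1).allclose(torch.ones(T, B)))  # False
print(implicit.exp().sum(dim=0).allclose(torch.ones(B, V)))   # True: normalized over time
print(implicit.allclose(F.log_softmax(logits, dim=0)))        # True
```

Applied to the posted code, the fix would be to change the last line of `Seq2Seq.forward` to `return F.log_softmax(o, dim=-1)` (equivalently `dim=2` for this layout).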
