LSTM-based auto-encoder overfitting

Here is my model:

Encoder:

import random

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()

        self.hid_dim = hid_dim
        self.n_layers = n_layers

        self.embedding = nn.Embedding(input_dim, emb_dim)

        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)

        self.dropout = nn.Dropout(dropout)         # currently unused in forward
        self.layer_norm_1 = nn.LayerNorm(emb_dim)  # currently unused in forward

    def forward(self, src):
        embedded = self.embedding(src)  # removed dropout after embedding layer
        outputs, (hidden, cell) = self.rnn(embedded)
        return hidden, cell

Decoder:

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()

        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers

        self.embedding = nn.Embedding(output_dim, emb_dim)

        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)

        self.fc_out = nn.Linear(hid_dim, output_dim)

        self.dropout = nn.Dropout(dropout)         # currently unused in forward
        self.layer_norm_1 = nn.LayerNorm(emb_dim)  # currently unused in forward

    def forward(self, input, hidden, cell):
        input = input.unsqueeze(0)        # [1, batch_size]
        embedded = self.embedding(input)  # removed dropout after embedding layer

        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        prediction = self.fc_out(output.squeeze(0))  # [batch_size, output_dim]

        return prediction, hidden, cell

Training:

def train(encoder, decoder, src_batches, trg_batches, optimizer, criterion, teacher_forcing_ratio=0.5):
    encoder.train()
    decoder.train()

    for src, trg in zip(src_batches, trg_batches):
        optimizer.zero_grad()

        src = src.transpose(-1, 0)
        trg = trg.transpose(-1, 0)

        src = src.to(device)
        trg = trg.to(device)
        batch_size = trg.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = decoder.output_dim

        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(device)  # trg_len x batch_size x vocab_size

        hidden, cell = encoder(src)  # src = src_len x batch_size

        input = trg[0, :]  # first decoder input (the start token)
        for t in range(1, trg_len):
            output_, hidden, cell = decoder(input, hidden, cell)
            outputs[t] = output_

            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output_.argmax(1)

            input = trg[t] if teacher_force else top1

        outputs = outputs[1:].reshape(-1, outputs.shape[2])  # every word in the batch and its predicted output
        trg = trg[1:].reshape(-1)

        l = criterion(outputs, trg)
        l.backward()
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), 1)
        optimizer.step()

    return l.detach().cpu().numpy()  # note: loss of the last batch only


I have tried various things, such as:

  1. used dropout; keeping dropout = 1 also gives the same problem
  2. increased the training data from 10K to 20K examples
  3. decreased the LSTM from 2 layers to 1
  4. added weight decay (1e-4); a sketch of how items 1 and 4 were wired in follows this list
  5. shuffled the data and added gradient clipping
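
A minimal sketch of what items 1 and 4 looked like (the exact hyperparameter values below are placeholders, not necessarily the ones I used):

# dropout applied right after the embedding lookup (inside Encoder/Decoder forward)
embedded = self.dropout(self.embedding(src))

# weight decay added on the optimizer over both models' parameters
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()),
    lr=1e-3,            # placeholder learning rate
    weight_decay=1e-4,
)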

Any help would be appreciated.

First of all, are you sure your validation and training data come from the same distribution? Maybe the training sentences are short and the validation sentences are long; I'm not sure.
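
A quick way to check (a rough sketch, assuming your splits are available as lists of token sequences named train_data and valid_data, which is just a guess at your setup):

import numpy as np

train_lens = np.array([len(s) for s in train_data])
valid_lens = np.array([len(s) for s in valid_data])

# compare the sentence-length distributions of the two splits
print("train: mean=%.1f  median=%.1f  max=%d" % (train_lens.mean(), np.median(train_lens), train_lens.max()))
print("valid: mean=%.1f  median=%.1f  max=%d" % (valid_lens.mean(), np.median(valid_lens), valid_lens.max()))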

Also, it seems like you are using trg_len = trg.shape[0], so all targets in a batch are the same length. Is that true? Specifically, this might be a padding issue: the "final" encoder state you pass on may actually be the result of propagating several padding zeros through the network. For example, suppose a source batch has two sentences of lengths (2, 10) and the target batch has lengths (7, 8). When you run the encoder and grab the "final" hidden state, the first sentence had to be padded with 8 zeros to reach length 10, so its final state is noisy and will not help the decoder; you are effectively passing noise to the decoder.
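
If that is what is happening, the usual fix is to pack the padded batch before the LSTM, so the returned hidden state comes from each sentence's last real token rather than the trailing pads. A sketch for your Encoder.forward, assuming you also keep a tensor of true source lengths (called src_lengths here, a name I am making up):

from torch.nn.utils.rnn import pack_padded_sequence

def forward(self, src, src_lengths):
    embedded = self.embedding(src)  # [src_len, batch_size, emb_dim]

    # pack so the LSTM stops at each sentence's true last token;
    # hidden/cell then come from real tokens, not padding
    packed = pack_padded_sequence(embedded, src_lengths.cpu(), enforce_sorted=False)
    _, (hidden, cell) = self.rnn(packed)
    return hidden, cell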

What does some example data look like? This might be way off, by the way; it's just something I have seen before …