Seq2Seq model not overfitting on a small sampled dataset

I am trying to reproduce the paper “Learning to Ask: Neural Question Generation for Reading Comprehension”. Our initial results didn’t make sense to us, so as a sanity check we decided to overfit on a small sampled subset of our original dataset, around 1000 question-sentence pairs.
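For context, the small split is nothing special: we just sample roughly 1000 pairs from the full dataset and wrap them in a new loader. The helper below is only a minimal sketch of that idea with made-up names, not our actual data pipeline:

import torch
from torch.utils.data import DataLoader, Subset

# Sketch only: randomly sample n_pairs examples so the model can be pushed to memorise them.
def make_overfit_loader(full_dataset, n_pairs=1000, batch_size=32):
    idx = torch.randperm(len(full_dataset))[:n_pairs].tolist()
    return DataLoader(Subset(full_dataset, idx), batch_size=batch_size, shuffle=True)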

After training for a sufficient time, the training loss goes down to approximately 0.4. However, when I evaluate the model on the same overfitted dataset, it always predicts the same token. I found this suspicious, so I decided to check the predictions during a training iteration to see whether they match the ground truth.

Essentially, I took an argmax over the decoder output just before it was passed to the loss function. A lot of the predicted words match the ground truth there, which explains why the loss is so low. Yet when I run the model through my evaluate function, where I predict word by word and update the hidden state after every prediction, the same word is always predicted.
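In case it helps, the check amounts to something like the snippet below, run inside the training loop right before the loss is computed. This is a sketch, not my exact training code; the shapes and variable names are assumptions about the teacher-forced decoder output:

# decoder_out: (batch, seq_len, vocab_size) logits from the teacher-forced decoder
# questions:   (batch, seq_len) ground-truth question indices
train_preds = decoder_out.argmax(dim=-1)
match_rate = (train_preds == questions).float().mean().item()
print("token match rate on this training batch:", round(match_rate, 3))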

Can anyone give some guidance on where I could be going wrong? My prediction code is below:

import numpy as np
import torch
from torch.utils.data import DataLoader


def greedy_search(encoder: EncoderBILSTM, decoder: DecoderLSTM, dev_loader: DataLoader, use_cuda: bool,
                  dev_idx_to_word_q: dict, dev_idx_to_word_a: dict, batch_size: int) -> None:
    encoder.eval()
    decoder.eval()
    max_len = 30
    for batch in dev_loader:

        questions, questions_org_len, answers, answers_org_len, pID = batch

        if use_cuda:
            questions = questions.cuda()
            questions_org_len = torch.LongTensor(np.asarray(questions_org_len)).cuda()
            answers = answers.cuda()
            answers_org_len = torch.FloatTensor(np.asarray(answers_org_len))

        encoder_input, encoder_len = answers, np.asarray(answers_org_len)
        decoder_input, decoder_len = questions, questions.shape[1]

        encoder_len = torch.LongTensor(encoder_len)
        # input to the first time step of the decoder is the <SOS> token (hard-coded here as index 1)
        decoder_inp = torch.ones((batch_size, 1), dtype=torch.long)
        if use_cuda:
            encoder_len = encoder_len.cuda()
            decoder_inp = decoder_inp.cuda()
        encoder_out, encoder_hidden = encoder(encoder_input, encoder_len)
        # the decoder hidden state is initialised from the final encoder hidden state
        decoder_hidden = encoder_hidden

        # decode greedily, one token per time step, for at most max_len steps
        seq_len = 0
        eval_mode = False
        predicted_sequences = []
        while seq_len < max_len:
            seq_len += 1
            decoder_out, decoder_hidden = decoder(decoder_inp, decoder_hidden, encoder_out, answers_org_len,
                                                  eval_mode=eval_mode)

            # log-softmax scores over the vocabulary; the greedy prediction is the argmax at each step
            decoder_out = decoder_out.view(batch_size, -1)
            decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=1)
            prediction = torch.argmax(decoder_out, 1).unsqueeze(1)
            predicted_sequences.append(prediction)
            # feed the greedy prediction back in as the next decoder input
            decoder_inp = prediction.clone()
            eval_mode = True
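After the decoding loop, the collected predictions still have to be mapped back to words before I can inspect them. A minimal sketch of that step, assuming dev_idx_to_word_q maps vocabulary indices back to question words (this is not my exact post-processing code), looks like:

# Sketch: stack the per-step predictions into a (batch, max_len) index matrix
# and look each index up in the question vocabulary.
predicted = torch.cat(predicted_sequences, dim=1)
for seq in predicted.tolist():
    words = [dev_idx_to_word_q[idx] for idx in seq]
    print(" ".join(words))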