Text summarisation project

I’m doing my first project in PyTorch as an undergraduate at university: abstractive text summarisation on the Gigaword dataset. My base model is a single-layer bidirectional encoder and a decoder, which I’m training from scratch. I’ve also tried to add an attention mechanism and to implement a pointer ability, so that some of the time the model can copy words directly (extractively) from the input. I’m using teacher forcing too, and my embeddings are pretrained 300-dimensional GloVe vectors.
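For reference, one common way to implement a copy/pointer ability is the pointer-generator mixing step (See et al., 2017): the final output distribution is p_gen times the decoder's vocabulary softmax plus (1 - p_gen) times the attention weights, added onto the vocabulary ids of the source tokens. The sketch below is a simplified version under stated assumptions. The function and argument names are hypothetical, and it assumes every source token already has an id in the fixed vocabulary, so it omits the extended vocabulary for out-of-vocabulary words used in the original paper.

```python
import torch


def pointer_generator_distribution(vocab_dist, attn_dist, src_ids, p_gen):
    """Mix generation and copy distributions, pointer-generator style.

    Assumed shapes (hypothetical names, not taken from the post):
      vocab_dist: (batch, vocab_size) softmax over the output vocabulary
      attn_dist:  (batch, src_len)    attention weights over source positions
      src_ids:    (batch, src_len)    vocabulary ids of the source tokens (long)
      p_gen:      (batch, 1)          probability of generating rather than copying
    """
    # Scale the decoder's own vocabulary distribution by p_gen.
    final = p_gen * vocab_dist
    # Spread the remaining (1 - p_gen) mass over the source tokens' ids;
    # scatter_add sums the weights when a token appears more than once.
    copy_mass = (1.0 - p_gen) * attn_dist
    return final.scatter_add(1, src_ids, copy_mass)
```

Because both input distributions sum to 1, each row of the result also sums to 1, so it can be fed straight into a negative log-likelihood loss (take its log first).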

I am using 1 million samples from the 3.5-million-sample dataset, and I’m mainly looking for help and advice on what results I should expect to see and how to improve them.

The model is currently about 7 of 15 epochs into training, and each epoch takes approximately 12 hours on Colab Pro+ GPUs. Is that to be expected for this kind of task? It seems long to me.
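A quick back-of-envelope check using only the figures above (1 million samples, batch size 32, about 12 hours per epoch) gives a per-batch time that can be compared against timing a single training step in isolation. It assumes each sample is seen exactly once per epoch.

```python
# Throughput check using the numbers from the post.
samples_per_epoch = 1_000_000
batch_size = 32
epoch_hours = 12

steps_per_epoch = samples_per_epoch / batch_size          # optimizer steps per epoch
seconds_per_step = epoch_hours * 3600 / steps_per_epoch   # wall-clock time per batch
print(f"{steps_per_epoch:.0f} steps/epoch, {seconds_per_step:.2f} s/step")
# prints "31250 steps/epoch, 1.38 s/step"
```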

As for training results, I’m using the cross-entropy loss function with the default learning rate. Within about 2 epochs the loss dropped quickly from 11 to 4.5, and it then took the next 5 epochs to get it below 4. It seems stuck there now. Again, is that common, or should it be converging towards zero faster?
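To help interpret those numbers: PyTorch's `nn.CrossEntropyLoss` uses the natural log, so a model that predicts uniformly over a vocabulary of size V starts near ln(V), and exp(loss) is the per-token perplexity. The arithmetic below uses only the 11 and 4 from the post; the vocabulary size it implies is an inference, not something stated in the post.

```python
import math

# Cross-entropy of a uniform guess over V classes is ln(V), so a starting
# loss of about 11 is consistent with a vocabulary of roughly this many tokens:
print(round(math.exp(11)))    # 59874

# Per-token perplexity is exp(loss): a loss just under 4 means the model is
# about as uncertain as a uniform choice over roughly this many tokens:
print(round(math.exp(4), 1))  # 54.6
```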

I have both greedy and beam search options for choosing the decoded sentence. However, in the screenshots below I’ve only included the greedy output, which takes the highest-probability token from the distribution at each timestep.
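A batched greedy decoder can track which sequences have already produced EOS, emit padding for those from then on, and stop as soon as every sequence in the batch has finished. The sketch below assumes a hypothetical `step_fn(prev_ids, state) -> (logits, state)` wrapping one decoder step; it is not the post's actual decoder.

```python
import torch


def greedy_decode(step_fn, start_ids, eos_id, pad_id, max_len):
    """Greedy decoding that stops each sequence at its first EOS.

    step_fn is a hypothetical decoder step: prev_ids has shape (batch,),
    and the returned logits have shape (batch, vocab_size).
    """
    prev = start_ids
    state = None
    finished = torch.zeros(start_ids.size(0), dtype=torch.bool, device=start_ids.device)
    outputs = []
    for _ in range(max_len):
        logits, state = step_fn(prev, state)
        next_ids = logits.argmax(dim=-1)  # highest-probability token per sequence
        # Sequences that already emitted EOS get padding instead of a new token.
        next_ids = torch.where(finished, torch.full_like(next_ids, pad_id), next_ids)
        outputs.append(next_ids)
        finished |= next_ids == eos_id
        if finished.all():  # every sequence is done, so stop early
            break
        prev = next_ids
    return torch.stack(outputs, dim=1)  # (batch, decoded_len)
```

Stopping once all sequences have finished also keeps the loop from running to `max_len` when every summary in the batch is short.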

I have some good results, such as…

Some reasonably good …

and some bad-ish…

Other Questions:

  • I’ve told the loss function to ignore the padding token. Would that explain why the output seemingly pads with ‘eos’ instead?

  • Since I’m using batches of size 32, the decoder keeps generating up to the length of the LONGEST summary in the batch. Does this have performance or quality implications?

  • Would you consider gradient clipping for a task like this? If so, what kind?

  • Any other hyperparameters worth playing with?

  • Would you expect better performance from using all 3.5 million samples and fewer epochs?
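On the padding question: with `ignore_index`, positions whose target is the padding id contribute no loss and no gradient. The model is therefore never trained on what to emit after EOS, so whatever it produces there (repeated ‘eos’ included) is unconstrained. A common approach is to cut each decoded sequence at its first EOS before reading or scoring it. The sketch below assumes hypothetical ids `PAD_IDX = 0` and `EOS_IDX = 2` and logits of shape (batch, tgt_len, vocab).

```python
import torch
import torch.nn as nn

PAD_IDX, EOS_IDX = 0, 2  # hypothetical special-token ids

# Targets equal to PAD_IDX are skipped; the mean is taken over the remaining tokens.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)


def summary_loss(logits, targets):
    """logits: (batch, tgt_len, vocab); targets: (batch, tgt_len) padded with PAD_IDX."""
    # CrossEntropyLoss expects (N, C) logits and (N,) targets, so fold time into batch.
    return criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))


def truncate_at_eos(ids, eos_id=EOS_IDX):
    """Cut each decoded sequence at its first EOS (the EOS itself is dropped)."""
    out = []
    for seq in ids.tolist():
        out.append(seq[:seq.index(eos_id)] if eos_id in seq else seq)
    return out
```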
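On gradient clipping: in PyTorch the usual tool for RNN models is `torch.nn.utils.clip_grad_norm_`, called between `loss.backward()` and `optimizer.step()`. It rescales all gradients together so that their combined L2 norm is at most `max_norm` (`torch.nn.utils.clip_grad_value_` instead clamps each gradient element). In the sketch below, `train_step` is a hypothetical helper, `max_norm=1.0` is only an illustrative value, and the model is assumed to return logits of shape (batch, tgt_len, vocab).

```python
import torch


def train_step(model, batch_inputs, batch_targets, criterion, optimizer, max_norm=1.0):
    """One optimisation step with gradient-norm clipping (illustrative sketch)."""
    model.train()
    optimizer.zero_grad()
    logits = model(batch_inputs)  # assumed shape: (batch, tgt_len, vocab)
    loss = criterion(logits.reshape(-1, logits.size(-1)), batch_targets.reshape(-1))
    loss.backward()
    # Clip the global gradient norm; the return value is the norm *before*
    # clipping, which is worth logging to see how often clipping kicks in.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return loss.item(), grad_norm.item()
```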

If you would like me to post any of my code, I would be glad to. I just didn’t want to make a long post even longer.