BERT for summarization (BertSum) problems

I am trying to implement the model as defined here --> https://github.com/nlpyang/BertSum

Each input to BertSum can be considered a document with multiple sentences separated by [SEP] [CLS] token pairs.
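For concreteness, here is a minimal sketch of how I understand that input format, assuming the Hugging Face `transformers` tokenizer and a cased checkpoint so the upper/lowercase signal survives tokenization (the original repo builds this differently, so treat the names here as illustrative only):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

sentences = ["THE FIRST SENTENCE IS UPPERCASE.", "the second one is lowercase."]

tokens, segment_ids = [], []
for i, sent in enumerate(sentences):
    # Each sentence is wrapped in its own [CLS] ... [SEP] pair.
    sent_tokens = ["[CLS]"] + tokenizer.tokenize(sent) + ["[SEP]"]
    tokens += sent_tokens
    # Segment (token_type) ids alternate 0/1 per sentence, as in BertSum.
    segment_ids += [i % 2] * len(sent_tokens)

input_ids = tokenizer.convert_tokens_to_ids(tokens)
# Positions of the per-sentence [CLS] tokens, whose hidden states are scored later.
cls_positions = [idx for idx, t in enumerate(tokens) if t == "[CLS]"]
```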

To make sure the basic model works, I am trying to overfit it on a toy dataset where each sentence is randomly assigned to uppercase or lowercase, and the model should predict 1 for every uppercase sentence and 0 for every lowercase one.
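Roughly, the toy data looks like this (the sentence texts and counts below are made up for illustration, not taken from my notebook):

```python
import random

base_sentences = ["the cat sat on the mat",
                  "bert reads every token",
                  "gradients should flow to cls"]

def make_example(n_sents=3):
    """Build one toy document: random upper/lowercase sentences with 1/0 labels."""
    sents, labels = [], []
    for _ in range(n_sents):
        s = random.choice(base_sentences)
        if random.random() < 0.5:
            sents.append(s.upper()); labels.append(1)   # uppercase -> label 1
        else:
            sents.append(s.lower()); labels.append(0)   # lowercase -> label 0
    return sents, labels
```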

The model does not converge when the input has multiple sentences separated by [CLS] tokens; it looks to me as if the [CLS] tokens are not being trained. However, when I have a single sentence, and therefore a single [CLS] token, it works (a basic classification task).
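By "the [CLS] tokens training" I mean a per-sentence scoring head along these lines. This is only a sketch of what I expect to happen, assuming the Hugging Face `BertModel` (the actual BertSum code is structured differently); each [CLS] hidden state is projected to one logit and trained with BCE against that sentence's 0/1 label:

```python
import torch
import torch.nn as nn
from transformers import BertModel

class ToySentenceScorer(nn.Module):
    def __init__(self, name="bert-base-cased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(name)
        self.score = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, token_type_ids, cls_positions):
        hidden = self.bert(input_ids=input_ids,
                           attention_mask=attention_mask,
                           token_type_ids=token_type_ids).last_hidden_state
        # Gather the hidden state at every [CLS] position: (batch, n_sents, hidden)
        batch_idx = torch.arange(hidden.size(0)).unsqueeze(1)
        cls_states = hidden[batch_idx, cls_positions]
        return self.score(cls_states).squeeze(-1)  # one logit per sentence

# Training would then use e.g.
# loss = nn.BCEWithLogitsLoss()(logits, labels.float())
```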

Any ideas what may be wrong? My notebook with all debugging output is here --> https://colab.research.google.com/drive/1NvrwbNAdYMguvu4KR2ReZUBHmj93PNkP?usp=sharing

Can you try using the Adam optimizer instead of the SGD optimizer?
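I just mean a swap along these lines (the learning rates here are placeholders, not values from your notebook):

```python
import torch

# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)   # current setup
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)     # suggested swap
```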

It's worse with Adam… The predictions become overconfident and the loss stops changing after a while.