Model loss difference between BertForSequenceClassification and Bert + nn.Linear

I was trying to fine-tune BERT for a continuous outcome (ranging from 0 to 400). I was doing this with transformers and PyTorch Lightning in Google Colab. I noticed a big difference in validation loss during training between loading the pre-trained BERT with BertForSequenceClassification and loading it with BertModel and writing the nn.Linear layer, dropout, and loss myself.

Specifically, using BertForSequenceClassification seems to work fine, with validation loss decreasing in each epoch. But if I load the same pre-trained BERT using BertModel.from_pretrained and write the linear layer, dropout, and loss myself, the validation loss quickly stagnates.

The data, seed, hardware, and most of the code are the same. The only difference is in the two snippets below: the first uses BertForSequenceClassification, the second uses BertModel with my own head.

        # Approach 1: BertForSequenceClassification builds the head and the loss.
        self.config.num_labels = self.hparams.num_labels
        self.model = transformers.BertForSequenceClassification.from_pretrained(self.hparams.bert_path, config=self.config)
        self.tokenizer = transformers.BertTokenizer.from_pretrained(self.hparams.bert_path)

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs[0]  # the model returns the loss first when the batch contains labels
        return loss


        # Approach 2: bare BertModel with a hand-written dropout, linear head, and loss.
        self.config.num_labels = self.hparams.num_labels
        self.model = transformers.BertModel.from_pretrained(self.hparams.bert_path, config=self.config)
        self.tokenizer = transformers.BertTokenizer.from_pretrained(self.hparams.bert_path)
        self.drop = torch.nn.Dropout(p=self.hparams.dropout)
        self.out = torch.nn.Linear(self.model.config.hidden_size, self.hparams.num_labels)
        self.loss = torch.nn.MSELoss()

    def forward(self, input_ids, att_mask):
        res = self.model(input_ids=input_ids, attention_mask=att_mask)
        dropout_output = self.drop(res.pooler_output)  # pooled [CLS] representation
        out = self.out(dropout_output)
        return out

    def training_step(self, batch, batch_idx):
        outputs = self(input_ids=batch["input_ids"], att_mask=batch["attention_mask"])
        # Flatten predictions and labels to 1-D so MSELoss compares them element-wise.
        loss = self.loss(outputs.view(-1), batch["labels"].view(-1))
        return loss

I’m lost as to why this is the case. I especially want the second approach to work so I can build on this further. Any advice is greatly appreciated!
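
One thing that may be worth checking is whether the hand-written head is set up the same way as the one BertForSequenceClassification builds internally. That class applies dropout with the config's hidden_dropout_prob (0.1 by default) and initialises its linear layer from a normal distribution with standard deviation initializer_range (0.02 by default) and a zero bias, whereas a plain nn.Linear uses PyTorch's default initialisation. The sketch below is a minimal stand-in, not the library's code: the names RegressionHead and trainable_parameter_names are hypothetical, the random tensor stands in for res.pooler_output, and whether these differences explain the stagnating loss is an assumption to test, not a confirmed cause.

```python
import torch
from torch import nn


class RegressionHead(nn.Module):
    """Hypothetical head mirroring how BertForSequenceClassification sets up its classifier."""

    def __init__(self, hidden_size, num_labels=1, dropout=0.1, init_std=0.02):
        super().__init__()
        self.drop = nn.Dropout(p=dropout)
        self.out = nn.Linear(hidden_size, num_labels)
        # Use a normal(0, init_std) weight and zero bias instead of nn.Linear's default init.
        nn.init.normal_(self.out.weight, mean=0.0, std=init_std)
        nn.init.zeros_(self.out.bias)

    def forward(self, pooled_output):
        return self.out(self.drop(pooled_output))


def trainable_parameter_names(module):
    """List trainable parameter names so a head missing from the optimizer is easy to spot."""
    return [name for name, p in module.named_parameters() if p.requires_grad]


head = RegressionHead(hidden_size=768)
pooled = torch.randn(4, 768)  # stand-in for res.pooler_output
preds = head(pooled)          # one prediction per example
names = trainable_parameter_names(head)
```

Calling trainable_parameter_names on the whole LightningModule and comparing the result with the parameters handed to the optimizer would show whether self.out is actually being trained, for example if the optimizer was built only from self.model.parameters().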

I am wondering: if this is a classification task, should you be using CrossEntropyLoss instead of MSELoss?
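
To make the loss question concrete: for a continuous target, and assuming num_labels is 1, BertForSequenceClassification treats the task as regression and computes an MSE loss itself, so MSELoss in the hand-written version is the matching choice. CrossEntropyLoss expects one logit per class and integer class labels, which does not fit a target ranging from 0 to 400. The sketch below uses made-up tensors (illustrative values, not the original data) to show the two input contracts, and why flattening both sides before MSELoss matters.

```python
import torch
from torch import nn

# Illustrative values only: 3 continuous targets in the 0 to 400 range.
labels = torch.tensor([12.0, 250.0, 399.5])
logits = torch.tensor([[10.0], [260.0], [390.0]])  # shape (3, 1): one output per example

# Regression: flatten both sides to shape (3,) so the comparison is element-wise.
mse = nn.MSELoss()(logits.view(-1), labels.view(-1))

# Without flattening, (3, 1) minus (3,) broadcasts to (3, 3), so the mean runs
# over every prediction/label pair and gives a different number.
pairwise = logits - labels
broadcast_mse = (pairwise ** 2).mean()

# Classification: CrossEntropyLoss takes (batch, num_classes) logits
# and integer class indices, not continuous targets.
class_logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3], [0.0, 0.2, 3.0]])
class_labels = torch.tensor([0, 1, 2])
ce = nn.CrossEntropyLoss()(class_logits, class_labels)
```

Since the question's training_step already calls .view(-1) on both the outputs and the labels, the loss function itself looks consistent with the regression setup.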