I was trying to fine-tune BERT on a continuous outcome (ranging from 0 to 400) with transformers and PyTorch Lightning in Google Colab. I noticed a big difference in validation loss during training between loading the pre-trained BERT with BertForSequenceClassification and loading it with BertModel and writing the nn.Linear, dropout, and loss myself.
Specifically, BertForSequenceClassification seems to work fine, with the validation loss decreasing every epoch. But if I load the same pre-trained BERT with BertModel.from_pretrained and write the linear layer, dropout, and loss myself, the validation loss quickly stagnates.
The data, seed, hardware, and most of the code are the same; the only differences are the snippets below. First, with BertForSequenceClassification:
```python
# in the LightningModule's __init__
self.config.num_labels = self.hparams.num_labels
self.model = transformers.BertForSequenceClassification.from_pretrained(self.hparams.bert_path, config=self.config)
self.tokenizer = transformers.BertTokenizer.from_pretrained(self.hparams.bert_path)

def forward(self, **inputs):
    return self.model(**inputs)

def training_step(self, batch, batch_idx):
    outputs = self(**batch)
    # when labels are in the batch, the loss is the first element of the output
    loss = outputs[0]
    return loss
```
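For comparison, here is roughly what BertForSequenceClassification computes on top of BertModel when num_labels == 1 (regression), as I understand the transformers source; details vary between versions. This is a standalone sketch with a hypothetical `BertStyleRegressionHead` class that takes a precomputed `pooler_output`, not part of my training code:

```python
import torch
import torch.nn as nn


class BertStyleRegressionHead(nn.Module):
    """Hypothetical head approximating BertForSequenceClassification
    with num_labels == 1: pooler_output -> dropout -> linear -> MSE."""

    def __init__(self, hidden_size: int = 768, dropout: float = 0.1,
                 initializer_range: float = 0.02):
        super().__init__()
        # transformers uses config.hidden_dropout_prob here by default
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, 1)
        # transformers initializes Linear weights from N(0, initializer_range)
        # with zero bias, instead of PyTorch's default Kaiming-uniform init.
        nn.init.normal_(self.classifier.weight, mean=0.0, std=initializer_range)
        nn.init.zeros_(self.classifier.bias)
        self.loss_fct = nn.MSELoss()

    def forward(self, pooler_output, labels=None):
        logits = self.classifier(self.dropout(pooler_output))  # (batch, 1)
        if labels is None:
            return logits
        # Cast targets to float so they match the logits' dtype for MSELoss.
        loss = self.loss_fct(logits.squeeze(-1), labels.float().view(-1))
        return loss, logits
```

`pooler_output` is expected to have shape `(batch, hidden_size)`, the same tensor the second snippet passes to its own dropout layer.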
Second, with BertModel and a hand-written head:
```python
# in the LightningModule's __init__
self.config.num_labels = self.hparams.num_labels
self.model = transformers.BertModel.from_pretrained(self.hparams.bert_path, config=self.config)
self.tokenizer = transformers.BertTokenizer.from_pretrained(self.hparams.bert_path)
self.drop = torch.nn.Dropout(p=self.hparams.dropout)
self.out = torch.nn.Linear(self.model.config.hidden_size, self.hparams.num_labels)
self.loss = torch.nn.MSELoss()

def forward(self, input_ids, att_mask):
    res = self.model(input_ids=input_ids, attention_mask=att_mask)
    # pooled [CLS] representation -> dropout -> linear regression head
    dropout_output = self.drop(res.pooler_output)
    out = self.out(dropout_output)
    return out

def training_step(self, batch, batch_idx):
    outputs = self(input_ids=batch["input_ids"], att_mask=batch["attention_mask"])
    loss = self.loss(outputs.view(-1), batch["labels"].view(-1))
    return loss
```
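One thing not shown above is configure_optimizers. If it builds its parameter groups only from `self.model` (a pattern some Lightning templates use), then in the second approach the new `self.out` layer would never be updated, whereas in the first approach the classifier lives inside `self.model`. A minimal check, assuming a hypothetical helper `params_missing_from_optimizer` and a toy module standing in for the real one:

```python
import torch
import torch.nn as nn


def params_missing_from_optimizer(module: nn.Module, optimizer: torch.optim.Optimizer):
    """Hypothetical debugging helper: names of trainable parameters of
    `module` that are not in any of the optimizer's param groups."""
    in_optimizer = {id(p) for group in optimizer.param_groups for p in group["params"]}
    return [name for name, p in module.named_parameters()
            if p.requires_grad and id(p) not in in_optimizer]


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 4)  # stands in for the BertModel backbone
        self.out = nn.Linear(4, 1)    # stands in for the hand-written head


toy = Toy()
# Optimizer built only from the backbone, so the head is left out.
opt = torch.optim.AdamW(toy.model.parameters(), lr=1e-3)
print(params_missing_from_optimizer(toy, opt))  # ['out.weight', 'out.bias']
```

In the real LightningModule, the same call on `self` with the optimizer returned by configure_optimizers should return an empty list if every layer is being trained.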
I’m at a loss as to why this happens. I especially want the second approach to work so that I can build on it further. Any advice is greatly appreciated!