Fine-tuning BERT-Base on SST2 with torchtext: dev accuracy not improving

I would like to fine-tune BERT-Base uncased on the SST2 dataset using torchtext. However, with my current setup the model doesn't seem to be training (i.e. the accuracy on the dev set stays the same across epochs). I have tried using the same hyperparameters that I have seen on Hugging Face. Any help is appreciated.

I have the latest versions of torch, torchtext, transformers, scikit-learn, and tqdm, along with any dependencies they may have.
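For reference, the installed versions can be checked with a quick snippet like this (nothing specific to the training script):

import torch, torchtext, transformers
print(torch.__version__, torchtext.__version__, transformers.__version__)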

Here is my training code:

import numpy as np
from sklearn.metrics import precision_recall_fscore_support
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchtext.datasets import SST2
from tqdm import tqdm
from transformers import BertTokenizer, BertForSequenceClassification

LR = 0.0005
EPOCHS = 5
BATCH_SIZE = 128

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
base_model = BertForSequenceClassification.from_pretrained(model_name)
max_input_length = 128


model = base_model.to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = AdamW(model.parameters(), lr=LR, betas=(0.9, 0.999), weight_decay=0.01)


train_datapipe = SST2(split="train")
valid_datapipe = SST2(split="dev")

# Collate function: tokenize each raw (text, label) pair and assemble the batch tensors
def collate_batch(batch):
    ids, types, masks, label_list = [], [], [], []
    for text, label in batch:
        tokenized = tokenizer(text,
                              padding="max_length", max_length=max_input_length,
                              truncation=True, return_tensors="pt")
        ids.append(tokenized['input_ids'])
        types.append(tokenized['token_type_ids'])
        masks.append(tokenized['attention_mask'])
        label_list.append(label)

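    # each encoded field has shape (1, max_input_length); stacking gives (batch, 1, max_input_length) and squeeze drops the singleton dim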
    input_data = {
        "input_ids": torch.squeeze(torch.stack(ids)).to(device),
        "token_type_ids": torch.squeeze(torch.stack(types)).to(device),
        "attention_mask": torch.squeeze(torch.stack(masks)).to(device)
    }
    label_list = torch.tensor(label_list, dtype=torch.int64)
    return input_data, label_list


train_dataloader = DataLoader(train_datapipe, shuffle=True, batch_size=BATCH_SIZE, collate_fn=collate_batch)
valid_dataloader = DataLoader(valid_datapipe, batch_size=BATCH_SIZE, collate_fn=collate_batch)

# print("total instances: ", len(train_dataloader))
for epoch in range(EPOCHS):
    model.train()
    train_loss = []
    all_labels = []
    all_outs = []

    for i, (input_data, label) in enumerate(tqdm(train_dataloader)):
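        # forward pass: the model output object's .logits field holds the classification scores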
        output = model(**input_data).logits
        label = label.to(device)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        train_loss.append(loss_value)

        label = label.cpu().numpy()
        all_labels.extend(label)
        predicted_labels = torch.argmax(output, dim=-1).cpu().numpy()
        all_outs.extend(predicted_labels)

    # epoch-level training metrics
    train_loss = np.mean(train_loss)
    print(train_loss)
    print(precision_recall_fscore_support(all_labels, all_outs))
    print("-" * 30)

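    # evaluate on the dev split (no gradient updates)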
    model.eval()
    val_loss = []
    val_accuracy = []
    with torch.no_grad():
        for i, (input_data, label) in enumerate(tqdm(valid_dataloader)):
            output = model(**input_data).logits
            label = label.to(device)
            loss = criterion(output, label)

            loss_value = loss.item()
            val_loss.append(loss_value)

            predicted_labels = torch.argmax(output, dim=-1)
            accuracy = (predicted_labels == label).cpu().numpy().mean() * 100
            val_accuracy.append(accuracy)

        print(np.mean(val_loss))
        print("accuracy:", np.mean(val_accuracy))
        print("=" * 30)

Here is the output:

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
527it [05:27, 1.61it/s]

5.145069390818109
(array([0.44484046, 0.55993406]), array([0.44425789, 0.56051532]), array([0.44454898, 0.56022454]), array([29780, 37569]))
------------------------------

7it [00:01, 3.89it/s]

10.386857986450195
accuracy: 50.94436813186813
==============================

527it [05:27, 1.61it/s]

19.96938668595772
(array([0.44756011, 0.56172413]), array([0.42501679, 0.58415183]), array([0.43599724, 0.57271849]), array([29780, 37569]))
------------------------------

7it [00:01, 3.89it/s]

43.005083901541575
accuracy: 50.94436813186813