Hi all,
I’m building a model for multiclass text classification with BERT. When i experimented with HuggingFace’s Trainer(), the training accuracy/f1 are approximately around 21% with only 5 epochs. However, when i tried to implement BertForSequenceClassification for multiclass (num_labels=30) with the training loop below, my accuracy/f1 keeps being around 5%. I would love a second pair of eyes to see whether I make a mistake at any point in this code.
#=================================================================
# One-hot encoding label - MULTICLASS
#=================================================================
lb = LabelBinarizer()
df['label'] = lb.fit_transform(df['label']).tolist()
num_labels = len(lb.classes_) # 35 classes
#=================================================================
# TRAINING
#=================================================================
MAX_LEN = 128
TRAIN_BATCH_SIZE = 8
VAL_BATCH_SIZE = 4
EPOCHS = 1
LEARNING_RATE = 0.01
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels = num_labels).to(device)
def train(epoch):
tr_loss = 0
n_correct = 0
nb_tr_steps = 0
nb_tr_examples = 0
model.train()
for _,data in enumerate(train_loader, 0):
ids = data['ids'].to(device)
mask = data['mask'].to(device)
token_type_ids = data['token_type_ids'].to(device)
targets = data['targets'].type(torch.LongTensor)
targets = targets.to(device)
optimizer.zero_grad()
outputs = model(ids, mask, token_type_ids)
loss = criterion(outputs.logits, torch.tensor(targets, dtype=float))
tr_loss += loss.item()
pred = torch.argmax(outputs.logits, dim = 1)
actual = torch.argmax(targets, dim=1)
n_correct += f1_score(pred.clone().detach().cpu(),actual.clone().detach().cpu(), average = 'weighted')
nb_tr_examples += TRAIN_BATCH_SIZE
nb_tr_steps += 1
loss.backward()
optimizer.step()
print(f'The Avg F1 for Epoch {epoch}: {(n_correct*100)/nb_tr_examples}')
epoch_loss = tr_loss/nb_tr_steps
epoch_accu = (n_correct)/nb_tr_examples
print(f"Training Loss Epoch: {epoch_loss}")
print(f"Training Accuracy Epoch: {epoch_accu}")