My custom BERT model’s architecture:
class BertArticleClassifier(nn.Module):
    """Article classifier: a pretrained transformer encoder followed by a small MLP head.

    ``forward`` returns raw class logits. The encoder output is used only for its
    ``pooler_output``; no loss is computed inside this module. The caller's
    ``criterion`` computes the loss on the returned logits, so ``loss.backward()``
    propagates through the head into the encoder, unless ``freeze_bert_weights``
    disabled the encoder's gradients.

    Args:
        n_classes: Number of output classes (size of the logits' last dimension).
        freeze_bert_weights: If True, encoder parameters get ``requires_grad=False``
            and only the head is trained.
        model_name: Checkpoint passed to ``AutoModel.from_pretrained``. Assumes a
            BERT-style encoder whose output has ``pooler_output`` and whose
            config has ``hidden_size`` — confirm before using other architectures.
        hidden_dim: Width of the intermediate fully-connected layer.
        dropout: Dropout probability applied to the pooled encoder output.
    """

    def __init__(self, n_classes, freeze_bert_weights=False,
                 model_name='bert-base-uncased', hidden_dim=256, dropout=0.1):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        if freeze_bert_weights:
            for param in self.bert.parameters():
                param.requires_grad = False
        self.dropout = nn.Dropout(dropout)
        # Take the encoder width from its config rather than hard-coding 768
        # (the bert-base value), so other checkpoint sizes also work.
        self.fc_1 = nn.Linear(self.bert.config.hidden_size, hidden_dim)
        self.leaky_relu = nn.LeakyReLU()
        self.fc_out = nn.Linear(hidden_dim, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return class logits for a batch of token ids and their attention mask."""
        # Pass by keyword so the mask cannot be bound to the wrong positional
        # slot of the encoder's forward().
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(output['pooler_output'])
        hidden = self.leaky_relu(self.fc_1(pooled))
        return self.fc_out(hidden)
`self.bert` is a model from the transformers library.
Training script:
def _train_one_epoch(model, optimizer, criterion, scheduler, dataloader_train, device, epoch):
    """Run one optimisation pass over ``dataloader_train`` and return the mean batch loss.

    Each batch is unpacked, in iteration order, into ``input_ids``, an attention
    mask and ``labels`` (assumes the Dataset yields a mapping with exactly these
    three tensors in that order — confirm against the Dataset).
    """
    model.train()
    loss_train_total = 0.0
    progress_bar = tqdm(dataloader_train, desc=f'Epoch {epoch :1d}', leave=False, disable=False)
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids, mask, labels = (batch[key].to(device) for key in batch)
        predictions = model(input_ids, mask)
        loss = criterion(predictions, labels)
        loss.backward()
        loss_train_total += loss.item()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # Stepped once per batch, so the scheduler is presumably configured in
        # optimisation steps (not epochs) — confirm where it is built.
        scheduler.step()
        # Report the per-batch loss directly. The original divided by len(batch),
        # which is the number of tensors in the tuple (3), not the batch size.
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item())})
    # Guard against an empty dataloader (division by zero).
    return loss_train_total / max(len(dataloader_train), 1)


def train_my_model(model, optimizer, criterion, scheduler, epochs, dataloader_train,
                   dataloader_validation, device, pretrained_weights=None,
                   checkpoint_dir='models_data/bert_my_model'):
    """Train ``model`` for ``epochs`` epochs, checkpointing and validating after each one.

    Args:
        model: Module called as ``model(input_ids, mask)`` and returning logits.
        optimizer, criterion, scheduler: Standard PyTorch training components;
            ``scheduler.step()`` is called after every batch.
        epochs: Number of epochs to run (epochs are numbered from 1).
        dataloader_train, dataloader_validation: Training and validation loaders.
        device: Device the batch tensors are moved to.
        pretrained_weights: Optional path to a saved ``state_dict`` that is loaded
            into ``model`` before training starts.
        checkpoint_dir: Directory that receives
            ``finetuned_BERT_epoch_{epoch}.model`` after each epoch. It is created
            if missing.
    """
    import os

    if pretrained_weights:
        # Load, do not save: the original called torch.save here, which
        # overwrote the pretrained file with the untrained weights.
        model.load_state_dict(torch.load(pretrained_weights, map_location=device))
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in tqdm(range(1, epochs + 1)):
        loss_train_avg = _train_one_epoch(model, optimizer, criterion, scheduler,
                                          dataloader_train, device, epoch)
        torch.save(model.state_dict(),
                   os.path.join(checkpoint_dir, f'finetuned_BERT_epoch_{epoch}.model'))
        tqdm.write(f'\nEpoch {epoch}')
        tqdm.write(f'Training loss: {loss_train_avg}')
        val_loss, predictions, true_vals = evaluate(model, dataloader_validation, criterion, device)
        val_f1 = f1_score_func(predictions, true_vals)
        tqdm.write(f'Validation loss: {val_loss}')
        tqdm.write(f'F1 Score (Weighted): {val_f1}')
Optimizer and Criterion:
# Optimizer and weighted loss for fine-tuning the whole model.
# NOTE(review): the original lr=1e-4 is above the 2e-5..5e-5 range usually
# recommended for fine-tuning BERT. At higher rates the encoder often collapses to
# a constant prediction. That would fit a validation loss stuck near 3.1 ≈ ln(22),
# i.e. uniform over ~22 classes — confirm n_classes.
# NOTE(review): if AdamW here is transformers.AdamW, it is deprecated in favour of
# torch.optim.AdamW (which has a different default weight_decay).
optimizer = AdamW(model.parameters(),
                  lr=2e-5,
                  eps=1e-6)
# as_tensor accepts a list, NumPy array or existing tensor; torch.tensor() warns
# when it is handed a tensor.
class_weights = torch.as_tensor(class_weights, dtype=torch.float).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights).to(device)
After 5 epochs, the validation loss is still stuck at ~3.1. I know my data is preprocessed correctly, because when I train the transformers `BertForSequenceClassification` model on it, the model does learn. The problem with that approach is that I cannot tweak its loss function to accept class weights, which is why I created my own custom model.
As you can see in the model’s `forward` method, I extract the `output['pooler_output']` piece and disregard the loss (which is returned alongside the `output['pooler_output']` element). My suspicion is that when I call `loss.backward()` in the training loop, the model’s weights may not be updating, because transformers BERT models return their own loss as an output.
What am I doing wrong?