The task is this:
For each query sentence, we want to find relevant answer sentences from around 800 candidates.
I use a kind of BERT-TFIDF ensemble for this task:
I compute two cosine similarities between the query sentence and each answer sentence (one from the BERT embeddings, one from the TFIDF vectors), sum them, min-max normalize the result, and predict that a query-answer pair is relevant if the normalized score exceeds a threshold tuned on the validation set.
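For concreteness, the scoring step looks roughly like this (a simplified sketch; the function and variable names are my own, and I normalize over the candidate set):

```python
import numpy as np

def ensemble_predict(bert_sims, tfidf_sims, threshold):
    """Combine BERT and TFIDF cosine similarities for one query.

    bert_sims, tfidf_sims: cosine similarities between the query and
    each of the candidate answer sentences.
    """
    combined = np.asarray(bert_sims) + np.asarray(tfidf_sims)
    # min-max normalize the summed scores over the candidate set
    normalized = (combined - combined.min()) / (combined.max() - combined.min())
    # predict "relevant" where the normalized score exceeds the threshold
    return normalized >= threshold
```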
The mystery is this:
When I make predictions with the general pretrained SentenceBERT model straight from Hugging Face (sonoisa/sentence-bert-base-ja-mean-tokens-v2), without fine-tuning on my dataset, I get an F2 score of 69.
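In case the metric matters: by F2 I mean the F-beta score with beta = 2, which weights recall higher than precision. A minimal pure-Python version (my own sketch, equivalent to scikit-learn's fbeta_score with beta=2):

```python
def f2_score(y_true, y_pred):
    """F-beta score with beta = 2 for binary labels (1 = relevant)."""
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    beta2 = 4  # beta ** 2
    denom = (1 + beta2) * tp + beta2 * fn + fp
    return (1 + beta2) * tp / denom if denom else 0.0
```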
But after I fine-tuned this model on my dataset, the ensemble's score went down, even though BERT alone (without the TFIDF ensemble) performs better after fine-tuning than before.
Sorry if my description is hard to follow. In short,
the trained BERT model does better than the untrained one on its own, but once combined with TFIDF, the untrained BERT does better than the trained one. How could that be?
I have omitted details to keep this concise, but please let me know if any necessary details are missing and I will clarify. For reference, I put the training loop for the BERT model below.
I would greatly appreciate any suggestions. Thank you!
```python
criterion = nn.CosineEmbeddingLoss()

for i, batch in enumerate(dataloader):
    # the collate function returns a dict of tensors
    sent_a_input_ids = batch["sent_a_input_ids"].to(device)
    sent_a_attention_mask = batch["sent_a_attention_mask"].to(device)
    sent_b_input_ids = batch["sent_b_input_ids"].to(device)
    sent_b_attention_mask = batch["sent_b_attention_mask"].to(device)
    # CosineEmbeddingLoss expects -1 if dissimilar, 1 if similar
    labels = torch.where(batch["labels"] == 0, -1, 1).to(device)

    a_out = model(input_ids=sent_a_input_ids,
                  attention_mask=sent_a_attention_mask).last_hidden_state
    b_out = model(input_ids=sent_b_input_ids,
                  attention_mask=sent_b_attention_mask).last_hidden_state
    a_emb = _mean_pooling(a_out, sent_a_attention_mask)
    b_emb = _mean_pooling(b_out, sent_b_attention_mask)

    loss = criterion(a_emb, b_emb, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```