I am using BERT-based encoding for a question answering task.
In short, I have two BERT encoders that take a question and a document as input and predict whether the document contains the answer to the question.
During training, I am trying to use in-batch training. For example, if I have a batch of N question-answer pairs, I train the model against an N×N target matrix that has 1 at position [i, j] if question[i] and document[j] match, and 0 otherwise. Because the batches are built so that the i-th question matches the i-th document, the target matrix is simply the identity matrix (1 on the diagonal, 0 everywhere else).
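A minimal sketch of the target matrix described above, assuming PyTorch (`in_batch_targets` is a hypothetical helper name, not part of my code). It also makes the class balance explicit: each row has one positive and N - 1 negatives, so the full matrix has N positives and N² - N negatives.

```python
import torch


def in_batch_targets(n: int) -> torch.Tensor:
    """Hypothetical helper: N x N targets for in-batch training.

    Assumes the i-th question and the i-th document are the only matching
    pair for row i, so the targets form an identity matrix.
    """
    # Float dtype because BCE-style losses expect floating-point targets.
    return torch.eye(n, dtype=torch.float32)


targets = in_batch_targets(4)
print(targets)

# N positives (the diagonal) versus N*N - N negatives (everything else).
num_pos = int(targets.sum().item())
num_neg = targets.numel() - num_pos
print(num_pos, num_neg)
```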
The cosine similarities between all question embeddings from BERT1 and all document embeddings from BERT2 in a batch are passed through a sigmoid, and the BCE loss is computed by comparing this predicted N×N matrix with the N×N target matrix described above. (In the code, the sigmoid is applied inside `nn.BCEWithLogitsLoss`.)
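The pipeline in the previous paragraph can be sketched with random unit vectors standing in for the two BERT outputs (an assumption, purely for illustration). It also shows two facts worth keeping in mind: cosine similarity is bounded to [-1, 1], so after the sigmoid the predicted probabilities can only fall between roughly 0.269 and 0.731; and applying a sigmoid followed by BCE gives the same loss as `binary_cross_entropy_with_logits` (the functional form of `nn.BCEWithLogitsLoss`) on the raw similarities.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, dim = 4, 8

# Hypothetical random embeddings standing in for the two BERT encoders' outputs,
# L2-normalized so that a dot product equals cosine similarity.
q = F.normalize(torch.randn(n, dim), dim=1)
d = F.normalize(torch.randn(n, dim), dim=1)

# N x N cosine similarity matrix: entry [i, j] compares question i with document j.
cosine_sim = q @ d.T

# The sigmoid of a value in [-1, 1] lies in [sigmoid(-1), sigmoid(1)].
probs = torch.sigmoid(cosine_sim)
low, high = torch.sigmoid(torch.tensor([-1.0, 1.0])).tolist()
print(f"reachable probability range: [{low:.3f}, {high:.3f}]")

targets = torch.eye(n)
# Explicit sigmoid + BCE versus the logits version applied to the raw similarities.
loss_explicit = F.binary_cross_entropy(probs, targets)
loss_logits = F.binary_cross_entropy_with_logits(cosine_sim, targets)
print(loss_explicit.item(), loss_logits.item())
```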
The problem is that the model does not seem to learn anything during training: it predicts the same label, either always 1 or always 0, for every question-document pair.
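One way to make this symptom measurable is to track the mean predicted probability on the diagonal (matching pairs) against the mean off the diagonal (non-matching pairs) during training; if the two stay close, the model is not separating matches from non-matches. The helper name and the example score matrix below are hypothetical, a sketch assuming the model returns raw N×N similarities as logits.

```python
import torch


def diagonal_vs_offdiagonal(scores: torch.Tensor) -> tuple[float, float]:
    """Hypothetical diagnostic: mean probability on matching vs. non-matching pairs.

    `scores` is assumed to be the N x N matrix of raw similarities (logits)
    returned by the model.
    """
    probs = torch.sigmoid(scores)
    n = probs.shape[0]
    diag_mask = torch.eye(n, dtype=torch.bool)
    return probs[diag_mask].mean().item(), probs[~diag_mask].mean().item()


# Example with a made-up score matrix where matching pairs score higher.
scores = torch.full((3, 3), -0.5)
scores.fill_diagonal_(0.9)
pos_mean, neg_mean = diagonal_vs_offdiagonal(scores)
print(f"diagonal: {pos_mean:.3f}, off-diagonal: {neg_mean:.3f}")
```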
Kindly let me know if I am missing something.
My code for computing the embeddings and the loss is detailed below.
```python
import torch
import torch.nn as nn
from transformers import BertModel


class BiEncoder(nn.Module):
    def __init__(self, tokenizer, bert_model_name):
        super(BiEncoder, self).__init__()
        # Two separate BERT encoders: one for questions, one for documents.
        # (tokenizer is accepted but not used inside the model.)
        self.bert_question = BertModel.from_pretrained(bert_model_name)
        self.bert_document = BertModel.from_pretrained(bert_model_name)

    def forward(self, input_ids_question, attention_mask_question,
                input_ids_document, attention_mask_document):
        output_question = self.bert_question(
            input_ids=input_ids_question,
            attention_mask=attention_mask_question
        )
        output_document = self.bert_document(
            input_ids=input_ids_document,
            attention_mask=attention_mask_document
        )
        # Pooled [CLS] representation, shape (N, hidden_size)
        question_embedding = output_question.pooler_output
        document_embedding = output_document.pooler_output
        # Normalization: L2-normalize so the dot product equals cosine similarity
        question_embedding = question_embedding / question_embedding.norm(dim=1, keepdim=True)
        document_embedding = document_embedding / document_embedding.norm(dim=1, keepdim=True)
        # Cosine similarity: N x N matrix, entry [i, j] = cos(question i, document j)
        cosine_sim = torch.matmul(question_embedding, document_embedding.T)
        return cosine_sim


class BiEncoderLossBatch(nn.Module):
    def __init__(self):
        super(BiEncoderLossBatch, self).__init__()
        # BCEWithLogitsLoss applies the sigmoid internally,
        # so `scores` are treated as logits here.
        self.bceloss = nn.BCEWithLogitsLoss()

    def forward(self, scores, targets):
        """scores: output of the BiEncoder model, N x N tensor.
        targets: actual 0/1 targets, N x N float tensor."""
        scores = scores.flatten()
        targets = targets.flatten()
        loss = self.bceloss(scores, targets)
        return loss
```
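To check the shapes and the loss wiring without downloading BERT, here is a hypothetical toy version of one training step: `ToyBiEncoder` is an illustrative stand-in in which `nn.Linear` layers replace the two `BertModel` encoders and random tensors replace tokenized inputs. It mirrors the setup described in the post (normalized embeddings, N×N cosine matrix, identity targets, `nn.BCEWithLogitsLoss` on flattened tensors); it is not meant to reproduce the training behaviour of the real model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyBiEncoder(nn.Module):
    """Hypothetical stand-in for BiEncoder: linear layers replace the two BERT models."""

    def __init__(self, in_dim: int, emb_dim: int):
        super().__init__()
        self.question_encoder = nn.Linear(in_dim, emb_dim)
        self.document_encoder = nn.Linear(in_dim, emb_dim)

    def forward(self, questions: torch.Tensor, documents: torch.Tensor) -> torch.Tensor:
        q = F.normalize(self.question_encoder(questions), dim=1)
        d = F.normalize(self.document_encoder(documents), dim=1)
        return q @ d.T  # N x N cosine similarity matrix


torch.manual_seed(0)
n, in_dim, emb_dim = 4, 16, 8
model = ToyBiEncoder(in_dim, emb_dim)
criterion = nn.BCEWithLogitsLoss()  # same loss as in BiEncoderLossBatch
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

questions = torch.randn(n, in_dim)
documents = torch.randn(n, in_dim)
targets = torch.eye(n)  # float targets, as BCEWithLogitsLoss expects

scores = model(questions, documents)
loss = criterion(scores.flatten(), targets.flatten())
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(scores.shape, loss.item())
```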