In-Batch All-vs-All Loss Function Implementation

I am using BERT-based encoding for a question answering task.
In short, I have two BERT encoders that take a question and a document as input and predict whether the document contains the answer to the question.

During training, I am trying to use in-batch training. For example, if I have a batch of N question-document pairs, I train the model by constructing an N×N target matrix that has a 1 in position [i, j] if question[i] and document[j] are a match, and 0 otherwise. Naturally, during creation of the batches, the i-th question and the i-th document are matches, so the target matrix is essentially the identity matrix.
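For concreteness, a minimal sketch of how such a target matrix can be built (the batch size N = 8 is just illustrative):

import torch

N = 8  # batch size (illustrative)
# Entry [i, j] is 1 only when question i is paired with document j,
# which happens exactly on the diagonal, so the target is the identity matrix.
targets = torch.eye(N)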

The cosine similarities in a batch between all the question embeddings from BERT1 and all the document embeddings from BERT2 are passed through a sigmoid layer, and then the BCE loss is calculated by comparing this predicted N×N matrix to the N×N target matrix described above.

The problem is that during training, the model does not seem to learn anything. It always predicts either a 1 or a 0 for all question-document pairs.

Kindly let me know if I am missing something.

My code for computing the embeddings and the loss is detailed below.

import torch
import torch.nn as nn
from transformers import BertModel


class BiEncoder(nn.Module):
    """Encodes questions and documents with two separate BERT models."""

    def __init__(self, tokenizer, bert_model_name):
        super(BiEncoder, self).__init__()
        self.bert_question = BertModel.from_pretrained(bert_model_name)
        self.bert_document = BertModel.from_pretrained(bert_model_name)

    def forward(self, input_ids_question, attention_mask_question,
                input_ids_document, attention_mask_document):
        output_question = self.bert_question(
            input_ids=input_ids_question,
            attention_mask=attention_mask_question
        )
        output_document = self.bert_document(
            input_ids=input_ids_document,
            attention_mask=attention_mask_document
        )
        question_embedding = output_question.pooler_output
        document_embedding = output_document.pooler_output

        # L2-normalize so the dot product below equals cosine similarity
        question_embedding = question_embedding / question_embedding.norm(dim=1, keepdim=True)
        document_embedding = document_embedding / document_embedding.norm(dim=1, keepdim=True)

        # All-vs-all cosine similarity: entry [i, j] compares question i with document j
        cosine_sim = torch.matmul(question_embedding, document_embedding.T)

        return cosine_sim


class BiEncoderLossBatch(nn.Module):
    """In-batch all-vs-all BCE loss over the N x N similarity matrix."""

    def __init__(self):
        super(BiEncoderLossBatch, self).__init__()
        # BCEWithLogitsLoss applies the sigmoid internally, so raw scores are passed in
        self.bceloss = nn.BCEWithLogitsLoss()

    def forward(self, scores, targets):
        """scores: output of the BiEncoder model, an N x N tensor.
        targets: actual 0/1 targets, an N x N tensor."""
        scores = scores.flatten()
        targets = targets.flatten()

        loss = self.bceloss(scores, targets)

        return loss
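Roughly, one training step looks like the sketch below; the batch keys and the train_step helper are illustrative placeholders for how my DataLoader packs the tensors, not my exact code.

import torch

def train_step(model, criterion, optimizer, batch):
    """One in-batch training step.

    batch is assumed to be a dict of already-tokenized tensors with the
    keys used below; adjust to however the DataLoader packs the batch.
    """
    scores = model(
        batch["input_ids_question"], batch["attention_mask_question"],
        batch["input_ids_document"], batch["attention_mask_document"],
    )
    # The i-th question matches the i-th document, so the target is the identity matrix
    targets = torch.eye(scores.size(0), device=scores.device)

    loss = criterion(scores, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()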
  1. I don’t see any non-linearity being used here (e.g. sigmoid or ReLU). Are your loss values decreasing during training? If not, your model is not learning.
  2. Instead of using cosine_sim, I would rather do this (a rough sketch follows after this list):
    a. Concatenate each question-document pair and pass it through a single pretrained BERT model.
    b. Add a linear layer on top of the output of the first token (assuming it’s the CLS token) and use BCELoss.
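A rough sketch of what option 2 could look like; the CrossEncoder name, the CLS-token pooling, and the single linear head are assumptions, not a definitive implementation:

import torch
import torch.nn as nn
from transformers import BertModel

class CrossEncoder(nn.Module):
    """Scores each (question, document) pair jointly with a single BERT."""

    def __init__(self, bert_model_name):
        super(CrossEncoder, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        # input_ids holds "[CLS] question [SEP] document [SEP]" sequences
        output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        cls_embedding = output.last_hidden_state[:, 0]  # hidden state of the first ([CLS]) token
        return self.classifier(cls_embedding).squeeze(-1)  # one raw score per pair

With the raw score above you would either apply a sigmoid and use nn.BCELoss as in 2b, or pass it directly to nn.BCEWithLogitsLoss, which applies the sigmoid internally.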

Hi. Thanks for the suggestions. I was able to figure out the issues.

The reason I wanted to find similarities between the embeddings of the questions and the documents is fast and easy search. During deployment, the document embeddings can be precomputed and stored. When a question comes in, we compute its embedding and find the nearest-neighbour document embedding, thereby saving a lot of time.
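As a minimal sketch of that retrieval step (the retrieve_top_k name is illustrative; at scale a vector index such as FAISS would typically replace the plain matrix multiplication):

import torch

@torch.no_grad()
def retrieve_top_k(question_embedding, document_embeddings, k=5):
    """Return the indices of the k most similar documents.

    question_embedding: (hidden,) L2-normalized question vector
    document_embeddings: (num_docs, hidden) precomputed, L2-normalized matrix
    """
    scores = document_embeddings @ question_embedding  # cosine similarities
    return torch.topk(scores, k=k).indices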