Training a bert-base cross encoder with triplet loss takes too long

Hello.

I’m trying to train a cross encoder on top of a bert-base model, and it takes about 1,400 hours to train one epoch of 330k batches (training already runs on a GPU).

The model is trained with triplet loss. During training I mine the hardest anchor-negative pairs by calling the model in inference mode: for each anchor I score all of its negative candidates with the cross encoder and rank them to keep the k hardest negatives. Compared with random negative sampling this is slow: preparing a batch takes about 15 seconds versus 1.3 seconds for the random method (I checked, and the call to the base model accounts for most of that time). That also matches the epoch estimate, since 330,000 batches × ~15 s ≈ 1,375 hours.
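
To make the comparison concrete, here is a simplified sketch of the two strategies (score_model, the candidate inputs, and the tensor names are just placeholders standing in for my cross encoder, not the real code):

import torch

def random_negatives(neg_candidate_indices, k):
  # Cheap path: pick k negatives uniformly at random, no model call at all.
  perm = torch.randperm(len(neg_candidate_indices))[:k]
  return neg_candidate_indices[perm]

def hardest_negatives(score_model, candidate_inputs, neg_candidate_indices, k):
  # Expensive path: one forward pass over every candidate of the anchor,
  # then keep the k lowest-scoring ("hardest") candidates.
  with torch.no_grad():
    scores = score_model(*candidate_inputs)  # shape: (num_candidates,)
  k = min(k, scores.shape[0])
  hardest = torch.topk(scores, k, largest=False).indices
  return neg_candidate_indices[hardest]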

Is there any strategy to deal with this problem? Or is this amount of time normal for such a setup?

Thank you for any help.

def __iter__(self):
    # Indices of all positive pairs (label == 1) in the pair index/label array.
    all_pos_pair_indices = np.where(self.pair_indices_labels[:, 0] == 1)[0]
    global mention_mention_model  # cross encoder used to score negative candidates
    start_pos_index = 0

    for _ in range(self.num_iterations):
      batch_indices = []
      # Take the next slice of positive pairs for this batch.
      pos_batch_indices = all_pos_pair_indices[start_pos_index:start_pos_index + self.num_pos_per_batch]
      start_pos_index += self.num_pos_per_batch

      with torch.no_grad():
        neg_batch_indices = []
        for pos_pair in self.pair_indices_labels[pos_batch_indices]:
          anchor_idx = pos_pair[1]
          # Re-score this anchor's negative candidates only if its cached top-k
          # list is missing or has been fully consumed.
          if anchor_idx not in self.top_k_neg_dict or self.neg_index_dict[anchor_idx] == len(self.top_k_neg_dict[anchor_idx]):
            neg_candidates = np.array(self.neg_indices_dict[anchor_idx])

            # Build model inputs for every negative candidate, then narrow down to the top-k hardest.
            neg_candidates_pairs = self.pair_indices_labels[neg_candidates]
            input_tokens_buffer = []
            mention_mask_a_buffer = []
            mention_mask_b_buffer = []

            for neg_pair in neg_candidates_pairs:
              # neg_pair layout: [label, mention_a_idx, mention_b_idx]
              mention_a_idx = neg_pair[1]
              mention_b_idx = neg_pair[2]
              input_tokens, mention_mask_a, mention_mask_b = AffinityDataset.generate_affinity_model_input(
                  mention_a_idx, mention_b_idx, self.training_mention_tokens,
                  self.training_mention_pos, self.max_len, self.tokenizer)

              input_tokens_buffer.append(input_tokens)
              mention_mask_a_buffer.append(mention_mask_a)
              mention_mask_b_buffer.append(mention_mask_b)

            # One batched forward pass over all candidates for this anchor;
            # this model call is where most of the ~15 s per batch goes.
            input_tokens_buffer = torch.stack(input_tokens_buffer).to(device)
            mention_mask_a_buffer = torch.stack(mention_mask_a_buffer).to(device)
            mention_mask_b_buffer = torch.stack(mention_mask_b_buffer).to(device)
            neg_affin = mention_mention_model(input_tokens_buffer, mention_mask_a_buffer, mention_mask_b_buffer)

            # Keep the k lowest-scoring candidates as the hard negatives
            # (or all of them if there are fewer than top_k_neg).
            k_eff = min(self.top_k_neg, neg_affin.shape[0])
            top_k_neg = torch.topk(neg_affin, k_eff, dim=0, largest=False, sorted=False)[1].cpu()

            # Cache the ranked negatives and reset the consumption cursor.
            self.top_k_neg_dict[anchor_idx] = neg_candidates[top_k_neg].flatten().tolist()
            self.neg_index_dict[anchor_idx] = 0

          # Consume the next cached hard negative for this anchor.
          neg_batch_indices.append(self.top_k_neg_dict[anchor_idx][self.neg_index_dict[anchor_idx]])
          self.neg_index_dict[anchor_idx] += 1

      # A batch is the slice of positive pairs plus one hard negative per anchor.
      batch_indices.extend(pos_batch_indices)
      batch_indices.extend(neg_batch_indices)


      yield batch_indices
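
For context, this is roughly how the sampler is plugged into training (names here are placeholders: train_dataset is the AffinityDataset instance and hard_neg_sampler is the object whose __iter__ is shown above, both built elsewhere):

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_sampler=hard_neg_sampler)

for batch in train_loader:
  # forward pass of the cross encoder, triplet loss, backward, optimizer step
  ...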