Hello. I’m trying to train my model with an in-batch loss. The following is the key part of the loss computation:
- There are three representations for (query, positive, negative).
- The shape of `in_batch_scores` is (bsize, bsize). The diagonal elements are the target scores for cross entropy; they are the dot products between queries and positives.
- I additionally compute `neg_scores` between queries and negatives, concatenate the two score tensors computed so far, and finally feed the combined scores to `F.cross_entropy`.
The problem is that the loss only oscillates and does not decrease. I checked that all the `parameters()` as well as the manually created tensors are marked as trainable (`requires_grad=True`). What could be the reason?
```python
import torch
import torch.nn.functional as F

def _compute_ibl(self, rep_q, rep_pos, rep_neg):
    # scores
    rep_q_cloned = rep_q.clone()
    in_batch_scores = rep_q @ rep_pos.T  # (bs, bs)
    neg_scores = (rep_q_cloned * rep_neg).sum(dim=1).unsqueeze(dim=1)  # (bs, 1)
    scores = torch.cat([in_batch_scores, neg_scores], dim=1)  # (bs, bs+1)

    # compute loss (earlier manual log-softmax version, kept for reference)
    # temp = torch.cat([in_batch_scores, neg_scores], dim=1)  # (bs, bs+1)
    # scores = F.log_softmax(temp, dim=1)
    # nb_columns = in_batch_scores.shape[0]  # square matrix: rows == columns
    # return torch.mean(
    #     -scores[
    #         torch.arange(in_batch_scores.shape[0]),
    #         torch.arange(nb_columns),
    #     ]
    # )

    # compute loss: the i-th query's positive sits in column i
    target = torch.arange(
        in_batch_scores.shape[0], device=scores.device, dtype=torch.long
    )
    return F.cross_entropy(scores, target)
```
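For reference, this is the kind of standalone sanity check I ran to confirm gradients actually flow through the loss (a minimal sketch with a copied-out, method-free version of the function; `bs` and `dim` are arbitrary values I picked for the check):

```python
import torch
import torch.nn.functional as F

# standalone copy of the loss above, without self, for the check
def compute_ibl(rep_q, rep_pos, rep_neg):
    in_batch_scores = rep_q @ rep_pos.T                       # (bs, bs)
    neg_scores = (rep_q * rep_neg).sum(dim=1, keepdim=True)   # (bs, 1)
    scores = torch.cat([in_batch_scores, neg_scores], dim=1)  # (bs, bs+1)
    target = torch.arange(scores.shape[0], device=scores.device)
    return F.cross_entropy(scores, target)

bs, dim = 8, 16  # arbitrary sizes, just for the check
rep_q = torch.randn(bs, dim, requires_grad=True)
rep_pos = torch.randn(bs, dim, requires_grad=True)
rep_neg = torch.randn(bs, dim, requires_grad=True)

loss = compute_ibl(rep_q, rep_pos, rep_neg)
loss.backward()
print(loss.item())
print(rep_q.grad.abs().sum().item())  # nonzero => gradients reach the query reps
```

The gradients come out nonzero here, which is why I suspect the issue is somewhere else in my setup.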