Hello. I’m trying to train my model with an in-batch loss. The following is the key part of the loss computation:
- There are three representations for (query, positive, negative).
- The shape of `in_batch_scores` is (bsize, bsize). The diagonal elements are the target scores for cross entropy; they are the dot products between queries and positives.
- I additionally compute `neg_scores` between queries and negatives, concatenate the two score tensors computed so far, and finally feed the combined scores to `F.cross_entropy`.
The problem is that the loss only oscillates and does not decrease. I checked that all the `parameters()` as well as the manually created tensors are marked as trainable (`requires_grad=True`). What could be the reason?
```python
import torch
import torch.nn.functional as F

def _compute_ibl(self, rep_q, rep_pos, rep_neg):
    # scores
    rep_q_cloned = rep_q.clone()
    in_batch_scores = rep_q @ rep_pos.T  # (bs, bs)
    neg_scores = (rep_q_cloned * rep_neg).sum(dim=1).unsqueeze(dim=1)  # (bs, 1)
    scores = torch.cat([in_batch_scores, neg_scores], dim=1)  # (bs, bs+1)

    # compute loss (earlier manual log-softmax version, kept for reference)
    # temp = torch.cat([in_batch_scores, neg_scores], dim=1)  # (bs, bs+1)
    # scores = F.log_softmax(temp, dim=1)
    # nb_columns = in_batch_scores.shape[0]  # square matrix: rows == columns
    # return torch.mean(
    #     -scores[
    #         torch.arange(in_batch_scores.shape[0]),
    #         torch.arange(nb_columns),
    #     ]
    # )

    # compute loss: the i-th query's positive sits in column i
    target = torch.arange(
        in_batch_scores.shape[0], device=scores.device, dtype=torch.long
    )
    return F.cross_entropy(scores, target)
```
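For reference, this is the kind of standalone sanity check I ran to confirm gradients actually flow through the loss (a minimal sketch with a copied-out, method-free version of the function; `bs` and `dim` are arbitrary values I picked for the check):

```python
import torch
import torch.nn.functional as F

# standalone copy of the loss above, without self, for the check
def compute_ibl(rep_q, rep_pos, rep_neg):
    in_batch_scores = rep_q @ rep_pos.T                       # (bs, bs)
    neg_scores = (rep_q * rep_neg).sum(dim=1, keepdim=True)   # (bs, 1)
    scores = torch.cat([in_batch_scores, neg_scores], dim=1)  # (bs, bs+1)
    target = torch.arange(scores.shape[0], device=scores.device)
    return F.cross_entropy(scores, target)

bs, dim = 8, 16  # arbitrary sizes, just for the check
rep_q = torch.randn(bs, dim, requires_grad=True)
rep_pos = torch.randn(bs, dim, requires_grad=True)
rep_neg = torch.randn(bs, dim, requires_grad=True)

loss = compute_ibl(rep_q, rep_pos, rep_neg)
loss.backward()
print(loss.item())
print(rep_q.grad.abs().sum().item())  # nonzero => gradients reach the query reps
```

The gradients come out nonzero here, which is why I suspect the issue is somewhere else in my setup.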