Why does the loss not decrease (in-batch loss)?

Hello. I’m trying to train my model with an in-batch loss. The key part of the loss computation is shown below.

  • There are three representations per example: (query, positive, negative).
  • The shape of in_batch_scores is (bsize, bsize). The diagonal elements, which are the dot products between each query and its own positive, are the target scores for cross entropy.
  • I additionally compute neg_scores between queries and negatives, concatenate them with in_batch_scores, and finally feed the combined scores to cross_entropy().

The problem is that the loss only oscillates and does not decrease. I checked all the parameters() as well as the tensors marked as trainable. What could be the reason?

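import torch
import torch.nn.functional as F

def _compute_ibl(self, rep_q, rep_pos, rep_neg):
    # scores
    rep_q_cloned = rep_q.clone()
    # query i vs. every positive in the batch; the diagonal holds the true pairs
    in_batch_scores = rep_q @ rep_pos.T  # (bs, bs)
    # query i vs. its own negative only (row-wise dot product)
    neg_scores = (rep_q_cloned * rep_neg).sum(dim=1).unsqueeze(dim=1)  # (bs, 1)
    scores = torch.cat([in_batch_scores, neg_scores], dim=1)  # (bs, bs+1)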

    # compute loss (earlier manual version, commented out)
    # temp = torch.cat([in_batch_scores, neg_scores], dim=1)  # (bs, bs+1)
    # scores = F.log_softmax(temp, dim=1)
    # nb_columns = in_batch_scores.shape[1]
    # return torch.mean(
    #     -scores[
    #         torch.arange(in_batch_scores.shape[0]),
    #         torch.arange(nb_columns),
    #     ]
    # )

    # compute loss
    target = torch.arange(
        in_batch_scores.shape[0], device=scores.device, dtype=torch.long
    )

    return F.cross_entropy(scores, target)
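
For anyone landing here with the same symptom, one way to rule out the loss itself is to run it on random tensors. The following is a minimal sketch, not the poster's actual setup: `in_batch_loss` is a hypothetical standalone version of the method above, and the batch size and embedding dimension are arbitrary. It checks that `F.cross_entropy` agrees with the commented-out `log_softmax` version and that a gradient reaches the query representation.

```python
import torch
import torch.nn.functional as F


def in_batch_loss(rep_q, rep_pos, rep_neg):
    """Hypothetical standalone version of _compute_ibl; also returns the scores."""
    in_batch_scores = rep_q @ rep_pos.T                       # (bs, bs)
    neg_scores = (rep_q * rep_neg).sum(dim=1, keepdim=True)   # (bs, 1)
    scores = torch.cat([in_batch_scores, neg_scores], dim=1)  # (bs, bs+1)
    # Row i's correct class is column i (its own positive).
    target = torch.arange(scores.shape[0], device=scores.device)
    return F.cross_entropy(scores, target), scores


torch.manual_seed(0)
bs, dim = 4, 8  # arbitrary sizes, chosen only for illustration
rep_q = torch.randn(bs, dim, requires_grad=True)
rep_pos = torch.randn(bs, dim)
rep_neg = torch.randn(bs, dim)

loss, scores = in_batch_loss(rep_q, rep_pos, rep_neg)

# Manual version: mean negative log-probability of the diagonal entries.
log_probs = F.log_softmax(scores, dim=1)
idx = torch.arange(bs)
manual_loss = -log_probs[idx, idx].mean()

# Check that a non-zero gradient flows back to the queries.
loss.backward()
grad_ok = rep_q.grad is not None and rep_q.grad.abs().sum().item() > 0
```

If both checks pass, the loss function is probably not the culprit and the optimizer or data pipeline is a better place to look.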

Just as a sanity check, have you checked whether the parameters that are expected to change every iteration are indeed changing (e.g., with something like params.sum())?
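
A rough sketch of that check, assuming a stand-in `torch.nn.Linear` model and a dummy loss rather than your actual network: snapshot the parameters, take one optimizer step, and compare.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 4)  # hypothetical stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Snapshot every parameter before the update.
snapshot = {name: p.detach().clone() for name, p in model.named_parameters()}

# One training step on dummy data with a dummy loss.
x = torch.randn(16, 8)
loss = model(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Per-parameter report: did the values actually move?
changed = {
    name: not torch.equal(p.detach(), snapshot[name])
    for name, p in model.named_parameters()
}
# Gradient norms help separate "no gradient" from "tiny learning rate".
grad_norms = {
    name: p.grad.norm().item()
    for name, p in model.named_parameters()
    if p.grad is not None
}
```

A parameter with a non-zero gradient norm that still reports `changed[name] == False` usually points at the optimizer or the learning rate rather than the loss.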

And (just because it wasn’t mentioned in your post), have you tried decreasing the learning rate?

Thanks for the suggestion. I actually found a solution! Because of a bug in my code, the learning rate schedule was set to ‘inf’, which kept the effective learning rate low and prevented the model from actually training. So it was related to the learning rate after all :) Thanks!
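
In case it helps others, a scheduler problem like this is easy to spot by logging the learning rate the optimizer actually uses at each step. The sketch below is only an illustration: the `Linear` model, the `LambdaLR` warmup, and `warmup_steps` are assumptions for the example, not the schedule from the code above.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 4)  # hypothetical stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

warmup_steps = 5  # assumed value, for illustration only
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    # Linear warmup factor that reaches 1.0 after warmup_steps steps.
    lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps),
)

lrs = []
for step in range(8):
    loss = model(torch.randn(16, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Record the learning rate that was used for this update.
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()

# Flag a schedule that never gets close to the configured base rate.
base_lr = 0.1
schedule_looks_stuck = max(lrs) < 0.01 * base_lr
```

Printing or plotting `lrs` for the first few hundred steps would have shown right away that the effective rate never reached the intended value.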