Computing retrieval Mean Reciprocal Rank (MRR) during PyTorch model training

Suppose that after a forward pass, the model outputs a prediction as shown below:

{
    "query_idx": tensor([0, 1, ..., 2]),
    "query_rpr": tensor([
        [0.1790, 0.4046, ..., 0.5882],
        ...
        [0.1207, 0.6405, ..., 0.0214]
    ]),
    "doc_idx": tensor([9, 5, ..., 7]),
    "doc_rpr": tensor([
        [0.290, 0.1045, ..., 0.8852],
        ...
        [0.774, 0.4056, ..., 0.1012]
    ])
}

where:

  • query_idx: the query index;
  • query_rpr: the query representation (embedding);
  • doc_idx: the document index;
  • doc_rpr: the document representation (embedding).

and the relevance map contains:

{  # query_idx -> relevant doc_idx values
    0: [31, 74, ..., 85],
    1: [15, 18, ..., 91],
    ...
    541247: [17, 32, ..., 49]
}
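Because a batch contains only some of the documents, the relevance map has to be matched against the batch's doc_idx values before any ranking metric can be computed. Below is a minimal sketch of that step, assuming the relevance map is a plain Python dict from query index to a list of document indices; the helper name build_relevance_mask and the toy values are hypothetical. It uses torch.isin (available in PyTorch 1.10+).

```python
import torch


def build_relevance_mask(query_idx: torch.Tensor,
                         doc_idx: torch.Tensor,
                         relevance_map: dict) -> torch.Tensor:
    """Hypothetical helper: (num_queries, num_docs) bool mask for one batch.

    mask[i, j] is True when doc_idx[j] is listed in relevance_map for
    query_idx[i]. Assumes relevance_map is a plain dict: int -> list of ints.
    """
    mask = torch.zeros(len(query_idx), len(doc_idx),
                       dtype=torch.bool, device=doc_idx.device)
    for i, q in enumerate(query_idx.tolist()):
        # Queries missing from the map get an empty list, i.e. no relevant docs.
        relevant = torch.tensor(relevance_map.get(q, []),
                                dtype=doc_idx.dtype, device=doc_idx.device)
        # True wherever a batch document id appears in this query's relevant list.
        mask[i] = torch.isin(doc_idx, relevant)
    return mask


# Toy values for illustration only.
mask = build_relevance_mask(torch.tensor([0, 1]),
                            torch.tensor([31, 15, 7]),
                            {0: [31, 74], 1: [15, 18]})
print(mask.tolist())  # [[True, False, False], [False, True, False]]
```

The Python loop runs once per query in the batch, which is usually cheap next to the forward pass; the dict lookup is kept on the Python side because the full relevance map (over 500k queries here) is not a tensor.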

Given these model outputs and the relevance map, how can I compute the MRR metric during model training?
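One possible approach, sketched under assumptions that are not in the question: score each query against every document in the same batch by dot product, sort the documents per query, and take the reciprocal of the 1-based rank of the first relevant document, averaged over queries. The function name batch_mrr is hypothetical, and relevance_mask is assumed to be a (num_queries, num_docs) bool tensor derived from the relevance map for the batch's query_idx and doc_idx.

```python
import torch


def batch_mrr(query_rpr: torch.Tensor,
              doc_rpr: torch.Tensor,
              relevance_mask: torch.Tensor) -> torch.Tensor:
    """Mean Reciprocal Rank over one batch (hypothetical helper).

    Assumptions, not taken from the original post:
      * relevance score = dot product of query and document embeddings;
      * each query is ranked only against the documents in the same batch;
      * relevance_mask[i, j] is True when batch doc j is relevant to query i;
      * a query with no relevant document in the batch contributes 0.
    """
    # (num_queries, num_docs) similarity matrix.
    scores = query_rpr @ doc_rpr.T
    # Column positions of the documents, highest score first, per query.
    order = scores.argsort(dim=1, descending=True)
    # Reorder relevance so column k refers to the k-th ranked document.
    ranked = relevance_mask.float().gather(1, order)
    # argmax returns the first maximal index, i.e. the 0-based rank of the
    # first relevant document (0 when no document is relevant).
    first_hit = ranked.argmax(dim=1)
    has_relevant = ranked.sum(dim=1) > 0
    # 1 / (rank + 1) for queries with a hit, 0 otherwise.
    reciprocal_rank = has_relevant.float() / (first_hit.float() + 1.0)
    return reciprocal_rank.mean()


# Toy example: query 0's relevant doc is ranked 2nd, query 1's is ranked 1st.
q = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
d = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
m = torch.tensor([[False, False, True], [False, True, False]])
print(batch_mrr(q, d, m).item())  # 0.75
```

In a training loop this would typically be called under torch.no_grad() on the forward-pass outputs, since the metric is for monitoring and argsort is not differentiable. Ranking only against in-batch documents makes the value depend on batch composition, so it is a cheap proxy rather than MRR over the full document collection; the latter requires scoring every document embedding for each query, which is usually done in a separate evaluation pass. Ties in scores are broken arbitrarily by argsort, and dropping queries without a relevant in-batch document (instead of counting them as 0) is an equally reasonable choice.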