Recommender system custom loss with kNN not working

Hi everyone

I’m trying to train a recommender system that takes as input a query (a 1xN vector) and an index (a QxN matrix), and performs a kNN search to find the k closest (most similar) index vectors. For supervised training, I want to compute the mean precision at k (here is a post about the metric): the more of the k recommended items share the query's label, the higher the score. The model is a simple MLP applied to the embeddings produced by another model. I want to use 1 minus the score as the loss function, but the resulting loss has no gradient, so training fails.
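Written out, the score is the fraction of the k retrieved neighbours that share the query's label, averaged over queries. With $y_q$ the label of query $q$, $n_1(q),\dots,n_k(q)$ its $k$ nearest index items, and $M$ the number of queries (notation introduced here for illustration; $M$ is `queries` in the code):

```latex
\mathrm{P@}k(q) = \frac{1}{k}\sum_{j=1}^{k} \mathbb{1}\left[\, y_{n_j(q)} = y_q \,\right],
\qquad
\mathrm{mP@}k = \frac{1}{M}\sum_{q=1}^{M} \mathrm{P@}k(q),
\qquad
\mathcal{L} = 1 - \mathrm{mP@}k
```

Both the indicator and the choice of neighbours $n_j(q)$ are piecewise constant in the embeddings, so $\mathcal{L}$ has zero gradient almost everywhere, even before looking at how autograd handles it.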

Here is my loss function and a minimum reproducible example:

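import numpy as np
import torch

def knn_mpatk_loss(query_arrays, query_labels, index_arrays, index_labels, k=3):
    queries = query_arrays.shape[0]
    # pairwise Euclidean distances between queries and index items, shape (queries, Q)
    distances = torch.cdist(query_arrays, index_arrays)
    # the k smallest distances = the k nearest neighbours of each query
    knn = distances.topk(k, largest=False)
    nbrs = knn.indices  # (queries, k) integer indices into the index
    nn_labels = index_labels[nbrs]
    # mean precision at k: fraction of retrieved neighbours with the query's label
    mAP = torch.sum(nn_labels == torch.unsqueeze(query_labels, 1)) / (k * queries)
    return mAP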

rnd_query_embeds = torch.tensor(np.random.random((10,100)), requires_grad=True)
rnd_query_labels = torch.tensor(np.random.randint(0,5,(10))/1.0, requires_grad=True)
rnd_index_embeds = torch.tensor(np.random.random((100,100)), requires_grad=True)
rnd_index_labels = torch.tensor(np.random.randint(0,5,(100))/1.0, requires_grad=True)

loss = 1 - knn_mpatk_loss(rnd_query_embeds, rnd_query_labels, rnd_index_embeds, rnd_index_labels, k=3)
print(loss.requires_grad)  # False


print(loss.requires_grad) prints False, and, as expected, starting the training gives:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Thank you.
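A hand-checkable example may help separate two questions: whether the metric is computed correctly, and whether it can be differentiated. The sketch below mirrors knn_mpatk_loss (with the distance typo fixed) in a hypothetical compact helper, precision_at_k, on two 1-D queries and four index points:

```python
import torch

def precision_at_k(query_x, query_y, index_x, index_y, k):
    # Hypothetical helper: same computation as knn_mpatk_loss, written compactly.
    nbrs = torch.cdist(query_x, index_x).topk(k, largest=False).indices  # (M, k)
    hits = index_y[nbrs] == query_y.unsqueeze(1)                         # (M, k) bool
    return hits.float().mean()                                           # = sum / (k * M)

# Two 1-D queries and four index points, small enough to check by hand.
query_x = torch.tensor([[0.0], [10.0]])
query_y = torch.tensor([0, 1])
index_x = torch.tensor([[0.1], [0.2], [9.9], [10.2]])
index_y = torch.tensor([0, 1, 1, 1])

score = precision_at_k(query_x, query_y, index_x, index_y, k=2)
print(score.item())  # 0.75
```

Query 0 retrieves labels 0 and 1 (one hit) and query 1 retrieves labels 1 and 1 (two hits), so the score is 3/4. As an evaluation metric the computation is fine; what it lacks is a gradient.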

On closer inspection, I think the problem might be the boolean comparison that produces mAP. Any suggestions for that step?

The code is now reproducible, and it shows that print(loss.requires_grad) outputs False.
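Checking requires_grad after each step shows where the graph is cut. A minimal sketch with small random tensors (sizes, seed and variable names are arbitrary stand-ins for the MLP output and the index):

```python
import torch

torch.manual_seed(0)
query = torch.randn(4, 8, requires_grad=True)   # stand-in for the MLP output
index = torch.randn(20, 8, requires_grad=True)
query_labels = torch.randint(0, 3, (4,))
index_labels = torch.randint(0, 3, (20,))

distances = torch.cdist(query, index)
knn = distances.topk(3, largest=False)
matches = index_labels[knn.indices] == query_labels.unsqueeze(1)
score = matches.sum() / (3 * 4)

print(distances.requires_grad)    # True: cdist is differentiable
print(knn.values.requires_grad)   # True: the selected distances keep a grad_fn
print(knn.indices.requires_grad)  # False: integer indices carry no gradient
print(matches.requires_grad)      # False: comparisons return a bool tensor
print(score.requires_grad)        # False: so 1 - score has no grad_fn either

try:
    (1 - score).backward()
except RuntimeError as err:
    print(err)  # the same RuntimeError as in the training loop
```

The comparison is not the only culprit: the neighbours are selected through integer indices, so even a differentiable comparison would give no gradient with respect to the embeddings. Only knn.values (the distances themselves) stay connected to the graph.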

Can you share what the distance function is? Also, please make sure the example code is a complete, minimal reproducible example, so that people can debug it themselves.

If you want a torch equivalent of ==, you can use torch.eq (elementwise) or torch.equal (whole-tensor check).
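The two functions behave differently: torch.eq (like ==) compares element by element and returns a bool tensor, while torch.equal returns a single Python bool for the whole tensor. Neither produces a gradient, which is consistent with the same error appearing after the switch. A small sketch with arbitrary values:

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
b = torch.tensor([1.0, 0.0, 3.0])

elementwise = torch.eq(a, b)   # same as a == b: one bool per element
whole = torch.equal(a, b)      # one Python bool for the whole tensor

print(elementwise.tolist())       # [True, False, True]
print(elementwise.requires_grad)  # False
print(whole)                      # False
```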

Hi, thank you for your answer. The distances line had a typo, which I have corrected. I tried torch.equal and got the same issue. For reproducibility, I added code that creates four random tensors and checks whether the function output has requires_grad=True.
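Because the hard top-k selection and the label comparison are piecewise constant, a common workaround is to train on a smooth surrogate and keep mean precision at k for evaluation only. The sketch below uses soft neighbour weights in the style of Neighbourhood Components Analysis; it is a swapped-in technique, not mP@k itself, and the function name soft_precision_loss and its temperature parameter are hypothetical:

```python
import torch
import torch.nn.functional as F

def soft_precision_loss(query_emb, query_labels, index_emb, index_labels, temperature=1.0):
    # NCA-style surrogate: every index item gets a soft "neighbour" weight
    # from softmax(-distance / T) instead of a hard top-k selection.
    distances = torch.cdist(query_emb, index_emb)                  # (M, Q)
    weights = F.softmax(-distances / temperature, dim=1)           # each row sums to 1
    same_label = (query_labels.unsqueeze(1) == index_labels.unsqueeze(0)).float()  # (M, Q)
    # The label mask is a constant; gradients flow through the weights.
    soft_precision = (weights * same_label).sum(dim=1)             # in [0, 1]
    return 1 - soft_precision.mean()

# Usage with random data shaped like the original example.
torch.manual_seed(0)
query = torch.randn(10, 100, requires_grad=True)
index = torch.randn(100, 100, requires_grad=True)
q_labels = torch.randint(0, 5, (10,))
i_labels = torch.randint(0, 5, (100,))

loss = soft_precision_loss(query, q_labels, index, i_labels, temperature=1.0)
loss.backward()
print(loss.requires_grad)  # True
print(query.grad.shape)    # torch.Size([10, 100])
```

A lower temperature concentrates the weights on the nearest index items (closer to a hard 1-NN), at the cost of near-zero gradients for more distant items, so it is likely something to tune.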