Recommender system custom loss with kNN not working

Hi everyone

I’m trying to train a recommender system that takes as input a query (a 1xN vector) and an index (a QxN matrix) and performs a kNN search to find the k closest (most similar) rows. For the supervised training, I want to compute the mean precision at k score (here is a post about the metric). In other words, the score is higher when the labels of the k recommended index entries match the label of the query. The model is a simple MLP that works on the output of another model that extracts embeddings. I want to use 1 minus the score as the loss function to train my model, but this doesn’t work.

Here is my loss function and a minimal reproducible example:

import numpy as np
import torch


def knn_mpatk_loss(query_arrays, query_labels, index_arrays, index_labels, k=3):
    queries = query_arrays.shape[0]

    # pairwise distances between each query and every index entry
    distances = torch.cdist(query_arrays, index_arrays)
    # indices of the k nearest neighbours of each query
    knn = distances.topk(k, largest=False)
    nbrs = knn.indices
    # labels of the retrieved neighbours
    nn_labels = index_labels[nbrs]
    # fraction of retrieved labels that match the query label
    mAP = torch.sum(nn_labels == torch.unsqueeze(query_labels, 1)) / (k * queries)

    return mAP

rnd_query_embeds = torch.tensor(np.random.random((10,100)), requires_grad=True)
rnd_query_labels = torch.tensor(np.random.randint(0,5,(10))/1.0, requires_grad=True)
rnd_index_embeds = torch.tensor(np.random.random((100,100)), requires_grad=True)
rnd_index_labels = torch.tensor(np.random.randint(0,5,(100))/1.0, requires_grad=True)

loss = 1 - knn_mpatk_loss(rnd_query_embeds, rnd_query_labels, rnd_index_embeds, rnd_index_labels, k=3)

print(loss.requires_grad)

I get False from print(loss.requires_grad) and, of course, if I start training I get:

----------
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Thank you.

EDIT:
After a closer inspection, I think the problem might be the boolean comparison used to compute mAP. Any suggestions for that step?
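
For example, checking requires_grad step by step on the random tensors from the example above (same variable names) seems to confirm where the graph breaks:

distances = torch.cdist(rnd_query_embeds, rnd_index_embeds)
print(distances.requires_grad)    # True: cdist keeps the computation graph

knn = distances.topk(3, largest=False)
print(knn.indices.requires_grad)  # False: topk indices are integer tensors, no gradient

matches = rnd_index_labels[knn.indices] == rnd_query_labels.unsqueeze(1)
print(matches.requires_grad)      # False: a boolean comparison has no grad_fn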

EDIT2:
The code is now reproducible, and it shows that print(loss.requires_grad) outputs False.

Can you share what the distance function is? (And could you make sure the example code you share is a complete, minimal reproducible example, so people can debug it themselves?)

If you want a torch equivalent of ==, you can just use torch.equal.
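
For example, a quick toy check (the tensor names here are just placeholders):

import torch

a = torch.randn(3, requires_grad=True)
b = a.clone()

print(torch.equal(a, b))  # True: compares the whole tensors and returns a plain Python bool
print(a == b)             # element-wise comparison, returns a bool tensor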


Hi, thank you for your answer. About the distances: I had made a typo and have now corrected it. I tried torch.equal and I get the same issue. To ensure reproducibility, I added the code for four random tensors and checked whether the function output has requires_grad=True.