How to find the closest point between two point sets

Hi,
I have two sets of 3D points, X and Y, both of shape (Batchsize, Num_Points, 3).

I’m trying to find the index of the closest point in Y for every point in X and, conversely, the index of the closest point in X for every point in Y.

Is there a PyTorch function I can apply directly?

Thank you in advance for your help!

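You can first compute the pairwise distance matrix between X and Y, then take the index of the smallest entry along each dimension: along the Y dimension to get the closest Y for each X, and along the X dimension to get the closest X for each Y.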
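A minimal sketch of that approach, assuming Euclidean distance. It uses `torch.cdist` for the batched distance matrix and `argmin` for the indices; the helper name `nearest_neighbors` is just for illustration. One `cdist` call serves both directions, because the X-to-Y and Y-to-X lookups are argmins over the two different dimensions of the same matrix.

```python
import torch

def nearest_neighbors(X: torch.Tensor, Y: torch.Tensor):
    """X: (B, N, 3), Y: (B, M, 3). Hypothetical helper, not a built-in.

    Returns:
      idx_xy: (B, N) index into Y of the closest point for each point in X
      idx_yx: (B, M) index into X of the closest point for each point in Y
    """
    # dist[b, i, j] = Euclidean distance between X[b, i] and Y[b, j]
    dist = torch.cdist(X, Y)       # shape (B, N, M)
    idx_xy = dist.argmin(dim=2)    # minimize over Y's points
    idx_yx = dist.argmin(dim=1)    # minimize over X's points
    return idx_xy, idx_yx

# Small example with one batch element
X = torch.tensor([[[0., 0., 0.], [10., 0., 0.]]])
Y = torch.tensor([[[9., 0., 0.], [1., 0., 0.], [50., 0., 0.]]])
idx_xy, idx_yx = nearest_neighbors(X, Y)
print(idx_xy)  # tensor([[1, 0]])
print(idx_yx)  # tensor([[1, 0, 1]])

# To fetch the nearest Y coordinates themselves, gather along the point dimension
Y_nearest = torch.gather(Y, 1, idx_xy.unsqueeze(-1).expand(-1, -1, Y.size(-1)))
print(Y_nearest)  # (B, N, 3): [[[1., 0., 0.], [9., 0., 0.]]]
```

Keep in mind that the full distance matrix needs B × N × M memory. For very large point clouds you may need to process X in chunks, or use a dedicated nearest-neighbor library.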