I am currently using `topk` to select elements from a similarity matrix.

This gives me, for each vector in `x`, the indices of the `k` most similar vectors in `y`.

How can I use these indices to pull out the corresponding vectors from `y`, as sketched below, without using loops?

```
import torch
import torch.nn.functional as F
x = torch.randn((10,64,512)) # (b, x_c, d)
y = torch.randn((10,128,512)) # (b, y_c, d)
sim_matrix = torch.matmul(F.normalize(x, p=2, dim=-1), F.normalize(y, p=2, dim=-1).permute(0, 2, 1)) # (b, x_c, y_c)
_, sim_matrix_ind = sim_matrix.topk(5, dim=2) # (b, x_c, 5)
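# Goal: select_y of shape (b, x_c, 5, d), where select_y[i, j, n] = y[i, sim_matrix_ind[i, j, n]]
# select_y = torch.gather(y, dim=1, index=...)  # which index tensor goes here?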
```
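One loop-free way to do this is sketched below. It assumes the shapes from the question: `y` is `(b, y_c, d)` and the `topk` indices are `(b, x_c, k)`. `torch.gather` needs an index tensor with the same number of dimensions as its input. The sketch therefore flattens the `(x_c, k)` indices into one axis, repeats each index across `d`, gathers along `dim=1`, and reshapes back. The helper names `select_topk_gather` and `select_topk_indexing` are hypothetical. The second helper produces the same selection with advanced indexing by pairing each index with its batch id.

```python
import torch
import torch.nn.functional as F


def select_topk_gather(y, ind):
    """Hypothetical helper: pick rows of y for every index in ind via torch.gather.

    y:   (b, y_c, d)
    ind: (b, x_c, k), int64 indices into the y_c dimension
    returns (b, x_c, k, d)
    """
    b, x_c, k = ind.shape
    d = y.size(-1)
    # Flatten (x_c, k) into one axis, then repeat each index across the
    # feature dimension d, because gather wants index.dim() == y.dim().
    flat_ind = ind.reshape(b, x_c * k, 1).expand(-1, -1, d)  # (b, x_c*k, d)
    out = torch.gather(y, dim=1, index=flat_ind)              # (b, x_c*k, d)
    return out.reshape(b, x_c, k, d)


def select_topk_indexing(y, ind):
    """Hypothetical helper: same result using advanced indexing."""
    b = y.size(0)
    # Batch ids broadcast against ind to (b, x_c, k); the trailing d dim is kept.
    batch_ids = torch.arange(b, device=y.device).view(b, 1, 1)
    return y[batch_ids, ind]                                  # (b, x_c, k, d)


torch.manual_seed(0)
x = torch.randn(10, 64, 512)   # (b, x_c, d)
y = torch.randn(10, 128, 512)  # (b, y_c, d)
sim = F.normalize(x, p=2, dim=-1) @ F.normalize(y, p=2, dim=-1).transpose(1, 2)
_, ind = sim.topk(5, dim=2)    # (b, x_c, 5)

sel_gather = select_topk_gather(y, ind)
sel_index = select_topk_indexing(y, ind)
print(sel_gather.shape)  # torch.Size([10, 64, 5, 512])
```

Flattening before `gather` avoids having to expand `y` to `(b, x_c, y_c, d)`. The advanced-indexing version is shorter. Both return a new `(b, x_c, k, d)` tensor, so memory grows with `x_c * k * d` per batch.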