Find the L2 Nearest Neighbor in a queue?

Hi :blush:, I’m having a hard time figuring out how to compute the following efficiently. For each row of `x`, I would like to compute the L2 distance to every row of `queue` and get the index of the nearest one. What I came up with is something like:

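```python
import torch

inp_size = 10
queue_size = 50
dim = 64

x = torch.randn([inp_size, dim])
queue = torch.randn([queue_size, dim])

# Pair every row of x with every row of queue by tiling both to
# (inp_size * queue_size, dim), then sum squared differences -> squared L2 distances.
score = (x.repeat_interleave(queue_size, 0) - queue.repeat(inp_size, 1)).pow(2).sum(1)
# Reshape to (inp_size, queue_size) and take the index of the smallest distance per row.
nn_idx = score.view(inp_size, queue_size).min(1)[1]
```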
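For comparison, here is a sketch of the same squared-distance matrix computed with broadcasting instead of `repeat`/`repeat_interleave`. It assumes the same shapes as above (the seed is only added for reproducibility). It still materializes an `(inp_size, queue_size, dim)` difference tensor, so it mainly simplifies the indexing rather than saving memory.

```python
import torch

torch.manual_seed(0)  # only for reproducibility of this sketch
inp_size, queue_size, dim = 10, 50, 64
x = torch.randn(inp_size, dim)
queue = torch.randn(queue_size, dim)

# x[:, None, :] is (inp_size, 1, dim) and queue[None, :, :] is (1, queue_size, dim);
# broadcasting the subtraction yields all pairwise differences, (inp_size, queue_size, dim).
diff = x[:, None, :] - queue[None, :, :]
# Sum of squared differences over the feature dimension = squared L2 distance.
score = diff.pow(2).sum(-1)      # (inp_size, queue_size)
# Index of the nearest queue row for each row of x.
nn_idx = score.argmin(dim=1)     # (inp_size,)
```

Since the square root is monotonic, taking `argmin` over squared distances gives the same nearest neighbor as over true L2 distances.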

Any advice?
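One possible direction, sketched under the same shape assumptions as the snippet above: avoid the `(inp_size, queue_size, dim)` intermediate entirely. `torch.cdist` returns the pairwise L2 distance matrix directly. Alternatively, the identity ||x − q||² = ||x||² − 2·x·q + ||q||² reduces the work to one matrix multiply plus two norm vectors. This is a generic technique, not something taken from the original post.

```python
import torch

torch.manual_seed(0)  # only for reproducibility of this sketch
inp_size, queue_size, dim = 10, 50, 64
x = torch.randn(inp_size, dim)
queue = torch.randn(queue_size, dim)

# Option A: built-in pairwise distances (p=2 is Euclidean), shape (inp_size, queue_size).
dist = torch.cdist(x, queue, p=2)
nn_idx_cdist = dist.argmin(dim=1)

# Option B: expand the squared distance and use a single matmul.
x_sq = x.pow(2).sum(1, keepdim=True)         # (inp_size, 1)
q_sq = queue.pow(2).sum(1).unsqueeze(0)      # (1, queue_size)
sq_dist = x_sq - 2 * (x @ queue.t()) + q_sq  # (inp_size, queue_size)
# Floating-point rounding can make near-zero entries slightly negative.
sq_dist = sq_dist.clamp(min=0)
nn_idx_mm = sq_dist.argmin(dim=1)
```

If only the index is needed, the `x_sq` term can be dropped, because it is constant along each row and does not change the `argmin`. The expansion trades some numerical precision for speed. That rarely matters for ranking, but it can matter if the distances themselves are used.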