I assume you are referring to this post.

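If so, then this batched approach should work. The loop computes the k nearest neighbors one test sample at a time, and the broadcast version after it gives the same result for the whole batch at once: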

```
import torch

data = torch.randn(100, 10)      # 100 reference samples, 10 features each
test_batch = torch.randn(5, 10)  # 5 test samples

# Loop version: one test sample at a time
for test in test_batch:
    test = test.unsqueeze(0)  # [1, 10], broadcasts against data
    dist = torch.norm(data - test, dim=1, p=None)  # [100] distances
    knn = dist.topk(3, largest=False)  # 3 smallest distances
    print('kNN dist: {}, index: {}'.format(knn.values, knn.indices))
# > kNN dist: tensor([2.2806, 2.4122, 2.5661]), index: tensor([97, 26, 13])
# > kNN dist: tensor([2.7386, 2.8757, 3.0352]), index: tensor([87, 92, 26])
# > kNN dist: tensor([1.8009, 1.8642, 2.3949]), index: tensor([75, 15, 37])
# > kNN dist: tensor([2.1696, 2.2909, 2.5022]), index: tensor([80, 76, 28])
# > kNN dist: tensor([1.9649, 2.0842, 2.2254]), index: tensor([27, 62, 58])

# Batched version: [100, 1, 10] - [1, 5, 10] broadcasts to [100, 5, 10],
# so dist2 is [100, 5] and each column holds one test sample's distances
dist2 = torch.norm(data.unsqueeze(1) - test_batch.unsqueeze(0), dim=2, p=None)
knn2 = dist2.topk(3, largest=False, dim=0)  # values and indices are [3, 5]
print('kNN dist:\n{}\nindex:\n{}'.format(knn2.values, knn2.indices))
# > kNN dist:
# tensor([[2.2806, 2.7386, 1.8009, 2.1696, 1.9649],
#         [2.4122, 2.8757, 1.8642, 2.2909, 2.0842],
#         [2.5661, 3.0352, 2.3949, 2.5022, 2.2254]])
# index:
# tensor([[97, 87, 75, 80, 27],
#         [26, 92, 15, 76, 62],
#         [13, 26, 37, 28, 58]])
```
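If you want to convince yourself that the loop and the batched version agree, a small check along these lines might help. This is only a sketch: the fixed seed, the `k` variable, and the `loop_vals` / `loop_idx` names are my own additions, and the `torch.cdist` variant at the end is an alternative I am swapping in, not part of the code above. It should avoid materializing the `[100, 5, 10]` intermediate, which may matter for larger inputs.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed, only to make the check repeatable
data = torch.randn(100, 10)
test_batch = torch.randn(5, 10)
k = 3

# Loop version, collected into [5, k] tensors for comparison
loop_vals, loop_idx = [], []
for test in test_batch:
    dist = torch.norm(data - test.unsqueeze(0), p=2, dim=1)  # [100]
    knn = dist.topk(k, largest=False)
    loop_vals.append(knn.values)
    loop_idx.append(knn.indices)
loop_vals = torch.stack(loop_vals)  # [5, k]
loop_idx = torch.stack(loop_idx)    # [5, k]

# Broadcast version: results come out as [k, 5], so transpose to compare
dist2 = torch.norm(data.unsqueeze(1) - test_batch.unsqueeze(0), p=2, dim=2)
knn2 = dist2.topk(k, largest=False, dim=0)

# Alternative: torch.cdist returns the [5, 100] pairwise distance matrix directly
dist3 = torch.cdist(test_batch, data, p=2.0)
knn3 = dist3.topk(k, largest=False, dim=1)  # values and indices are [5, k]

print(torch.allclose(loop_vals, knn2.values.t(), atol=1e-5))
print(torch.allclose(loop_vals, knn3.values, atol=1e-4))
print(loop_idx)
print(knn2.indices.t())
```

Note that the `cdist` result is laid out as `[num_test, k]` rather than `[k, num_test]`, and its distances can differ from the broadcast version in the last few digits, hence the looser tolerance.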