Is there a fast, vectorised solution for dilated k-nearest neighbours?

I am implementing the dilated k-nearest-neighbour algorithm. Unfortunately, my implementation has nested loops, and they severely hamper execution speed.

import torch

dilation = 3
nbd_size = 5
knn_key = torch.randint(0, 30, (64, 12, 198, 100))

dilated_keys = torch.zeros((knn_key.shape[0], knn_key.shape[1], knn_key.shape[2], nbd_size))

for i in range(knn_key.shape[0]):
    for j in range(knn_key.shape[1]):
        for k in range(knn_key.shape[2]):
            list_indices = []
            while len(list_indices) < nbd_size:
                for l in range(knn_key.shape[3]):
                    # keep neighbours whose key is in the same residue class as row k
                    if knn_key[i][j][k][l] % dilation == k % dilation:
                        list_indices.append(knn_key[i][j][k][l])
                        if len(list_indices) >= nbd_size:
                            break
            dilated_keys[i][j][k] = torch.tensor(list_indices)

The variable knn_key stores the 100 nearest neighbours among the 1000 data points originally available. dilated_keys stores the nbd_size = 5 neighbour indices that remain after applying the dilation filter. Any help with a broadcasting solution that removes the three nested loops would be highly appreciated.

It looks like the inner loop can be reduced as follows:

import torch

dilation = 3
nbd_size = 5
knn_key = torch.randint(0, 30, (64, 12, 198, 100))

dilated_keys = torch.zeros((knn_key.shape[0], knn_key.shape[1], knn_key.shape[2], nbd_size), dtype=torch.int64)

for i in range(knn_key.shape[0]):
    for j in range(knn_key.shape[1]):
        for k in range(knn_key.shape[2]):
            key = knn_key[i, j, k]
            # positions of the keys that fall in the same residue class as row k
            indices = torch.nonzero(key % dilation == k % dilation).squeeze(-1)
            selected_indices = indices[:nbd_size]
            dilated_keys[i, j, k] = key[selected_indices]

However, the three nested loops are still there. Any help would be highly appreciated.
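All three loops can be removed with broadcasting plus a stable sort. The idea: build a boolean mask marking which of the 100 candidate keys share row k's residue class, then stable-sort each row by the inverted mask so the matching keys move to the front in their original nearest-neighbour order, and keep the first nbd_size of them. A sketch, assuming (as the loop versions above implicitly do) that every row has at least nbd_size matching keys; otherwise non-matching keys leak into the tail of the selection:

```python
import torch

dilation = 3
nbd_size = 5
knn_key = torch.randint(0, 30, (64, 12, 198, 100))

# Residue class of each row index k, shaped for broadcasting over the last dim.
k_idx = torch.arange(knn_key.shape[2]).view(1, 1, -1, 1)
mask = (knn_key % dilation) == (k_idx % dilation)   # (64, 12, 198, 100), bool

# Stable sort on the inverted mask: matching keys (sort key 0) move to the
# front while keeping their original nearest-neighbour order.
order = (~mask).long().sort(dim=-1, stable=True).indices

# Reorder the keys accordingly and keep the first nbd_size matches per row.
dilated_keys = torch.gather(knn_key, -1, order)[..., :nbd_size]
```

With 100 candidates uniformly distributed over 30 values and dilation = 3, each row has roughly 33 matches on average, so the assumption holds in practice here; a sum over the mask (mask.sum(-1)) can verify it for real data.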