I am implementing the dilated k-nearest-neighbour algorithm. My current implementation relies on three nested Python loops over the tensor dimensions, plus an inner scan over the neighbours, and this makes execution very slow.
import torch

dilation = 3
nbd_size = 5
# torch.randint takes low, high and size as separate arguments
knn_key = torch.randint(0, 30, (64, 12, 198, 100))
dilated_keys = torch.zeros([knn_key.shape[0], knn_key.shape[1], knn_key.shape[2], nbd_size])
for i in range(knn_key.shape[0]):
    for j in range(knn_key.shape[1]):
        for k in range(knn_key.shape[2]):
            list_indices = []
            # Rescans the row until nbd_size entries are collected;
            # this never terminates if no entry in the row matches.
            while len(list_indices) < nbd_size:
                for l in range(knn_key.shape[3]):
                    # Keep neighbours whose residue matches that of row k
                    if knn_key[i, j, k, l] % dilation == k % dilation:
                        list_indices.append(knn_key[i, j, k, l])
                    if len(list_indices) >= nbd_size:
                        break
            list_indices_tensor = torch.tensor(list_indices)
            dilated_keys[i, j, k] = list_indices_tensor
The variable knn_key stores the 100 nearest neighbours among the 1000 data points originally available. The variable dilated_keys stores the nbd_size=5 neighbour indices that remain after applying the dilation filter. Any help with a broadcasting-based solution that removes the three nested loops would be much appreciated.
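
For reference, here is a minimal sketch of one way the loops might be replaced by tensor operations. The helper name dilated_select and the extra valid mask are hypothetical, introduced only for illustration. It assumes every row contains at least nbd_size matching entries; where a row has fewer, the loop version would rescan and duplicate entries, whereas this sketch pads with non-matching entries and marks them False in valid. The idea is to build the residue mask by broadcasting, give non-matching positions a larger sort key, and take the nbd_size smallest keys with torch.topk.

```python
import torch

def dilated_select(knn_key, dilation, nbd_size):
    """Hypothetical helper: for every row k, keep the first nbd_size entries
    whose value % dilation equals k % dilation, in their original order."""
    N, K = knn_key.shape[-2], knn_key.shape[-1]
    device = knn_key.device
    # Residue each row must match, shaped (N, 1) so it broadcasts over the last axis.
    row_residue = (torch.arange(N, device=device) % dilation).unsqueeze(-1)
    mask = (knn_key % dilation) == row_residue  # bool, same shape as knn_key
    # Sort key: a matching entry keeps its position l, a non-matching one gets l + K,
    # so the nbd_size smallest keys are the first matches in their original order.
    pos = torch.arange(K, device=device)
    sort_key = torch.where(mask, pos, pos + K)
    values, order = torch.topk(sort_key, nbd_size, dim=-1, largest=False, sorted=True)
    selected = knn_key.gather(-1, order)
    valid = values < K  # False where a row had fewer than nbd_size matches
    return selected, valid

torch.manual_seed(0)
dilation, nbd_size = 3, 5
knn_key = torch.randint(0, 30, (4, 2, 9, 100))  # small shape for a quick check
dilated_keys, valid = dilated_select(knn_key, dilation, nbd_size)
print(dilated_keys.shape)  # torch.Size([4, 2, 9, 5])
```

Note that the result keeps the integer dtype of knn_key rather than the float dtype produced by torch.zeros in the loop version; a .float() call can be added if the downstream code expects floats. The same function should work on the full (64, 12, 198, 100) tensor, at the cost of a few temporary tensors of that size.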