Replace features with the features of K neighbors without using a for loop

In a KNN-GCN model, I am trying to replace the original features with the features of their K nearest neighbors. The input has size (B,N,D), where B is the batch size, N is the number of channels, and D is the feature dimension. The output should have size (B,N,K,D), where K is the KNN parameter.

I first tried to do this with repeat and gather. However, this creates a tensor of size (B,N,N,D), which wastes GPU memory because N is large. The code is:

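import torch

def knn(x, k):
    # x: (B, D, N) here, because the caller permutes the (B, N, D) input
    inner = -2 * torch.matmul(x.transpose(2, 1), x)       # (B, N, N)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)           # (B, 1, N)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)  # negative squared distance
    idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (batch_size, num_points, k)
    return idx

# x: (B, N, D) input features
top_idx = knn(x.permute(0, 2, 1), K)               # (B, N, K)
top_idx = top_idx.unsqueeze(3).repeat(1, 1, 1, D)  # (B, N, K, D)
# x.unsqueeze(1).repeat(1, N, 1, 1) materializes a (B, N, N, D) copy of x
output = torch.gather(x.unsqueeze(1).repeat(1, N, 1, 1), 2, top_idx)  # (B, N, K, D)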
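To put the memory cost in numbers, here is a back-of-the-envelope calculation. The sizes B=8, N=1024, D=64, K=20 and the float32 dtype are hypothetical, chosen only to show how the repeated tensor scales with N rather than K; it counts forward tensors only, not gradients.

```python
import math

# Hypothetical sizes for illustration only; they are not from the question.
B, N, D, K = 8, 1024, 64, 20

def tensor_bytes(*shape, itemsize=4):
    """Bytes needed by a dense tensor of this shape (itemsize=4 for float32)."""
    return math.prod(shape) * itemsize

repeated = tensor_bytes(B, N, N, D)           # x.unsqueeze(1).repeat(1, N, 1, 1)
needed = tensor_bytes(B, N, K, D)             # the (B, N, K, D) output itself
index = tensor_bytes(B, N, K, D, itemsize=8)  # int64 top_idx after .repeat(1, 1, 1, D)

MiB = 2 ** 20
print(f"repeated x: {repeated / MiB:.0f} MiB")  # repeated x: 2048 MiB
print(f"output:     {needed / MiB:.0f} MiB")    # output:     40 MiB
print(f"top_idx:    {index / MiB:.0f} MiB")     # top_idx:    80 MiB
print(f"ratio:      {repeated / needed:.1f}")   # ratio:      51.2
```

The ratio between the repeated copy and the output is simply N/K, so the overhead grows linearly with the number of channels. The repeated int64 index is also twice the size of the float32 output.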

The function knn() returns a tensor of size (B,N,K).
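As a quick sanity check of what knn() computes, the sketch below compares its picks with torch.cdist on small random data (the sizes are hypothetical). The expression -xx - inner - xx.transpose(2, 1) equals the negative squared Euclidean distance, so topk over it selects the K closest points. Because a point's distance to itself is zero, each point normally appears among its own K neighbors.

```python
import torch

def knn(x, k):
    # Same computation as in the question; x has shape (B, D, N).
    inner = -2 * torch.matmul(x.transpose(2, 1), x)       # (B, N, N)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)           # (B, 1, N)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)  # -||x_i - x_j||^2
    return pairwise_distance.topk(k=k, dim=-1)[1]         # (B, N, k)

torch.manual_seed(0)
B, N, D, K = 2, 10, 3, 4     # small hypothetical sizes
x = torch.randn(B, N, D)     # features in the question's (B, N, D) layout
idx = knn(x.permute(0, 2, 1), K)
print(idx.shape)             # torch.Size([2, 10, 4])

# Reference: the K smallest Euclidean distances according to torch.cdist.
dist = torch.cdist(x, x)                          # (B, N, N)
ref_d = dist.topk(k=K, dim=-1, largest=False)[0]  # ascending distances
knn_d = dist.gather(-1, idx)                      # distances of knn()'s picks
print(torch.allclose(knn_d.sort(dim=-1)[0], ref_d, atol=1e-5))  # True
```

Comparing distances rather than raw indices keeps the check robust to near-ties, where the two methods may order equally distant points differently.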

To avoid creating the (B,N,N,D) tensor, I used a for loop that runs K times. The code is:

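top_idx = knn(x.permute(0, 2, 1), K)               # (B, N, K)
top_idx = top_idx.unsqueeze(3).repeat(1, 1, 1, D)  # (B, N, K, D)
# Collect one (B, N, 1, D) slice per neighbor and concatenate once at the end
neighbors = []
for i in range(K):
    # temp[b, n, d] = x[b, top_idx[b, n, i, d], d], the i-th neighbor of point n
    temp = torch.gather(x, 1, top_idx[:, :, i, :]).unsqueeze(2)  # (B, N, 1, D)
    neighbors.append(temp)
output = torch.cat(neighbors, dim=2)               # (B, N, K, D)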
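One possible way to remove the loop, sketched under the assumption that x is (B,N,D) and the index is the (B,N,K) output of knn(): offset each batch's indices by b*N, flatten x to (B*N, D), and index all B*N*K rows in a single operation. This is the offset-and-flatten pattern commonly seen in DGCNN-style code; gather_neighbors_flat is a hypothetical helper name, and random indices stand in for knn().

```python
import torch

def gather_neighbors_flat(x, idx):
    """Hypothetical helper: x (B, N, D), idx (B, N, K) -> (B, N, K, D), no loop."""
    B, N, D = x.shape
    K = idx.shape[-1]
    # Shift each batch's indices so they address rows of x flattened to (B*N, D).
    offset = torch.arange(B, device=idx.device).view(B, 1, 1) * N
    flat_idx = (idx + offset).reshape(-1)             # (B*N*K,)
    return x.reshape(B * N, D)[flat_idx].view(B, N, K, D)

torch.manual_seed(0)
B, N, D, K = 2, 7, 5, 3
x = torch.randn(B, N, D)
idx = torch.randint(0, N, (B, N, K))                  # stand-in for knn(...)

# The loop-based version from the question, for comparison.
top_idx = idx.unsqueeze(3).repeat(1, 1, 1, D)
looped = torch.cat(
    [torch.gather(x, 1, top_idx[:, :, i, :]).unsqueeze(2) for i in range(K)],
    dim=2,
)

out = gather_neighbors_flat(x, idx)
print(out.shape)                 # torch.Size([2, 7, 3, 5])
print(torch.equal(out, looped))  # True
```

Only the (B,N,K,D) result and a (B*N*K,) index are allocated, so memory scales with K rather than N.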

Is there any way to replace the features with the features of the K neighbors without a for loop? The input has size (B,N,D), the KNN index has size (B,N,K), and the output should have size (B,N,K,D).
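Another loop-free sketch uses PyTorch advanced indexing with a broadcast batch index, so neither x nor the index needs to be repeated. gather_neighbors is a hypothetical helper name and the sizes are arbitrary; the check compares it with the repeat/gather approach from the question on a tiny input where the (B,N,N,D) copy is cheap.

```python
import torch

def gather_neighbors(x, idx):
    """Hypothetical helper: x (B, N, D), idx (B, N, K) -> (B, N, K, D).

    batch_idx (B, 1, 1) broadcasts against idx (B, N, K), so
    out[b, n, k, :] = x[b, idx[b, n, k], :].
    """
    B = x.shape[0]
    batch_idx = torch.arange(B, device=x.device).view(B, 1, 1)
    return x[batch_idx, idx]

torch.manual_seed(0)
B, N, D, K = 2, 6, 4, 3
x = torch.randn(B, N, D)
idx = torch.randint(0, N, (B, N, K))  # stand-in for knn(...)

out = gather_neighbors(x, idx)
# The question's first approach, which materializes a (B, N, N, D) tensor.
ref = torch.gather(
    x.unsqueeze(1).repeat(1, N, 1, 1), 2, idx.unsqueeze(3).repeat(1, 1, 1, D)
)
print(out.shape)              # torch.Size([2, 6, 3, 4])
print(torch.equal(out, ref))  # True
```

With the question's knn(), the call would be gather_neighbors(x, knn(x.permute(0, 2, 1), K)). The index stays (B,N,K) instead of being repeated to (B,N,K,D), and, as far as I know, the backward pass of advanced indexing accumulates gradients into a tensor the size of x rather than anything of size (B,N,N,D).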