Indexing the adjacency matrix of a kNN graph

Say I have a top-k index matrix P (B*N*k), a weight matrix W (B*N*N), and a target matrix A (B*N*N). I want to build an adjacency matrix that behaves like the following loops:

for i in range(B):
    for ii in range(N):
        for j in range(k):
            if weighted:
                A[i][ii][P[i][ii][j]] = W[i][ii][P[i][ii][j]]
            else:
                A[i][ii][P[i][ii][j]] = 1

How can I implement this more efficiently and concisely?

You could use a scatter_/gather approach, shown here for the weighted case:

import torch

B, N, k = 2, 3, 4

P = torch.randint(0, N, (B, N, k))
W = torch.randn(B, N, N)
A = torch.zeros(B, N, N)
A_ = A.clone()

# loop
for i in range(B):
    for ii in range(N):
        for j in range(k):
            A[i][ii][P[i][ii][j]] = W[i][ii][P[i][ii][j]]

# scatter/gather approach
A_.scatter_(2, P, torch.gather(W, 2, P))

# check
print((A == A_).all())
> tensor(True)
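
The unweighted branch of the original loop (writing 1 instead of a weight) can probably be handled the same way, since scatter_ also accepts a scalar value in place of a source tensor. Below is a minimal sketch that covers both branches. The helper name knn_adjacency and its signature are made up for illustration, and it assumes P holds int64 indices in [0, N) and that the graph is square (N taken from P.shape[1]).

```python
import torch

def knn_adjacency(P, W=None):
    """Hypothetical helper: build a (B, N, N) adjacency matrix from
    top-k indices P of shape (B, N, k).

    If W (B, N, N) is given, selected entries take their weight from W;
    otherwise they are set to 1. All other entries stay 0.
    """
    B, N, _ = P.shape  # assumption: square graph, N = P.shape[1]
    if W is not None:
        # weighted: copy W[b, n, P[b, n, j]] into A[b, n, P[b, n, j]]
        A = torch.zeros_like(W)
        A.scatter_(2, P, torch.gather(W, 2, P))
    else:
        # unweighted: write the scalar 1.0 at every selected index
        A = torch.zeros(B, N, N)
        A.scatter_(2, P, 1.0)
    return A

B, N, k = 2, 5, 3
P = torch.randint(0, N, (B, N, k))  # int64 by default
W = torch.randn(B, N, N)

A_weighted = knn_adjacency(P, W)
A_binary = knn_adjacency(P)
```

Duplicate indices within a row of P should be harmless here, because every duplicate writes the same value to the same position.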