Say I have a top-k index matrix `P` of shape (B, N, k), a weight matrix `W` of shape (B, N, N), and a target matrix `A` of shape (B, N, N). I want to fill `A` as an adjacency matrix, the way the following loops do:
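```python
# For each batch i and row ii, write into A the k columns selected by P[i][ii]
for i in range(B):
    for ii in range(N):
        for j in range(k):
            if weighted:
                # copy the weight of the selected edge
                A[i][ii][P[i][ii][j]] = W[i][ii][P[i][ii][j]]
            else:
                # unweighted: just mark the edge as present
                A[i][ii][P[i][ii][j]] = 1
```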
How can I implement this more efficiently and concisely?
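A minimal vectorized sketch of the loops above, assuming `A`, `P` and `W` are NumPy arrays (the question does not name a library) and that `P` holds integer column indices. The idea is a gather followed by a scatter along the last axis: `np.take_along_axis` picks `W[i, ii, P[i, ii, j]]` for every `(i, ii, j)`, and `np.put_along_axis` writes those values into `A` at the same positions, in place, just as the loops do. The function name `fill_topk` is hypothetical.

```python
import numpy as np

def fill_topk(A, P, W=None, weighted=True):
    """Hypothetical helper: vectorized version of the triple loop.

    A: (B, N, N) array, modified in place (like the loops) and returned
    P: (B, N, k) integer array of column indices
    W: (B, N, N) weights, only needed when weighted=True
    """
    if weighted:
        # Gather: vals[i, ii, j] = W[i, ii, P[i, ii, j]]
        vals = np.take_along_axis(W, P, axis=2)
    else:
        # Unweighted: every selected edge gets the value 1
        vals = np.ones(P.shape, dtype=A.dtype)
    # Scatter: A[i, ii, P[i, ii, j]] = vals[i, ii, j]
    np.put_along_axis(A, P, vals, axis=2)
    return A

# Usage: top-2 columns per row of a random weight matrix
rng = np.random.default_rng(0)
W = rng.random((2, 4, 4))
P = np.argsort(-W, axis=2)[:, :, :2]  # indices of the 2 largest weights per row
A = fill_topk(np.zeros_like(W), P, W, weighted=True)
```

If the arrays are PyTorch tensors instead (e.g. `P` coming from `torch.topk`, which returns int64 indices), the same gather/scatter pattern should be `A.scatter_(2, P, W.gather(2, P))` for the weighted case and `A.scatter_(2, P, 1.0)` for the unweighted one.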