I am trying to write a model that includes a batched 2D index-selection step, and I'd like to ask what the best (fastest) way is to perform it without any for loops.

Input:

B: batch size; N: size of each (square) matrix; K1, K2: numbers of row and column candidates, with K1 << N and K2 << N.

X: a 3D binary tensor of shape [B, N, N], with X[b][i][j] ∈ {0, 1}.

S: a 2D index tensor of shape [B, K1], with S[b][k1] ∈ {0, 1, 2, …, N-1} (row candidates).

E: a 2D index tensor of shape [B, K2], with E[b][k2] ∈ {0, 1, 2, …, N-1} (column candidates).
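To pin down the expected types, here is a minimal sketch of inputs with these shapes. It assumes PyTorch (the post does not name a framework), stores X as bool, and keeps S and E as int64 so they can be used directly as indices; the sizes are arbitrary.

```python
import torch

# Arbitrary illustrative sizes (assumption, not from the original question).
B, N, K1, K2 = 2, 8, 3, 2

torch.manual_seed(0)
# X: binary tensor of shape [B, N, N]; stored as bool here (assumption).
X = torch.randint(0, 2, (B, N, N)).bool()
# S: row candidates, E: column candidates; int64 so they work as index tensors.
S = torch.randint(0, N, (B, K1))
E = torch.randint(0, N, (B, K2))
```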

Output:

C: a 2D tensor of shape [M, 3], where M is the number of matching triples (not known in advance). Each row c of C satisfies X[c[0]][c[1]][c[2]] = 1, c[1] ∈ S[c[0], :], and c[2] ∈ E[c[0], :].
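One loop-free approach, sketched here assuming PyTorch and int64 index tensors, is to gather only the K1 × K2 candidate entries of each matrix with broadcast advanced indexing, find the ones with `nonzero`, and map the candidate positions back to the original row and column indices. This touches B·K1·K2 elements instead of B·N·N, which matters when K1, K2 << N. The function name `select_candidates` is hypothetical.

```python
import torch

def select_candidates(X, S, E):
    """Hypothetical helper: return C of shape [M, 3] with rows (b, i, j) such that
    X[b, i, j] != 0, i in S[b], and j in E[b]. Assumes S and E are int64."""
    B = X.shape[0]
    b_idx = torch.arange(B, device=X.device)[:, None, None]  # [B, 1, 1]
    # The three index tensors broadcast to [B, K1, K2]:
    # sub[b, k1, k2] = X[b, S[b, k1], E[b, k2]]
    sub = X[b_idx, S[:, :, None], E[:, None, :]]
    # Positions (b, k1, k2) of the nonzero candidate entries, in row-major order.
    b, k1, k2 = sub.nonzero(as_tuple=True)
    # Translate candidate positions back into row/column indices of X.
    return torch.stack((b, S[b, k1], E[b, k2]), dim=1)  # [M, 3]

# Small usage example.
X = torch.zeros(2, 5, 5, dtype=torch.long)
X[0, 1, 3] = 1
X[0, 2, 4] = 1  # row 2 is not in S[0], so this entry is excluded
X[1, 0, 0] = 1
S = torch.tensor([[1, 3], [0, 4]])
E = torch.tensor([[3, 0], [0, 2]])
print(select_candidates(X, S, E).tolist())  # [[0, 1, 3], [1, 0, 0]]
```

If S[b] or E[b] can contain repeated indices, the same triple will appear more than once; `torch.unique(C, dim=0)` collapses those duplicates (and sorts the rows). An alternative is to scatter S and E into [B, N] boolean row and column masks and call `nonzero` on `X & row_mask[:, :, None] & col_mask[:, None, :]`, which deduplicates automatically but materializes a full [B, N, N] mask.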