Hello,
I would like to ask for your help in writing a vectorized version of the following simple operation.
I have a 3D tensor x of shape (N, B, V) and I would like to get the elements of x at indices given by two (N, K) tensors idx1 and idx2 as follows: y[i, j] = x[i, idx1[i,j], idx2[i,j]].
With a for loop, this can be done with the following function:
import torch

def f(x, idx1, idx2):
    """Compute y with an explicit double for loop.
    Args:
        x: N x B x V
        idx1: N x K matrix where idx1[i, j] is in [0, B)
        idx2: N x K matrix where idx2[i, j] is in [0, V)
    Return:
        y: N x K matrix where y[i, j] = x[i, idx1[i,j], idx2[i,j]]
    """
    N, K = idx1.shape
    # N x K output with the same dtype and device as x
    y = torch.zeros(N, K, dtype=x.dtype, device=x.device)
    for i in range(N):
        for j in range(K):
            y[i, j] = x[i, idx1[i,j], idx2[i,j]]
    return y
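A minimal sketch of one possible vectorization, assuming idx1 and idx2 are int64 tensors on the same device as x. PyTorch advanced indexing broadcasts all index tensors against each other, so pairing idx1 and idx2 with an (N, 1) row index selects x[i, idx1[i, j], idx2[i, j]] for every (i, j) in one call. The name f_indexing is hypothetical.

```python
import torch

def f_indexing(x, idx1, idx2):
    """Vectorized y[i, j] = x[i, idx1[i, j], idx2[i, j]] via advanced indexing.

    Hypothetical helper; assumes idx1 and idx2 are (N, K) int64 tensors
    on the same device as x, which has shape (N, B, V).
    """
    N = x.shape[0]
    # (N, 1) row index that broadcasts against the (N, K) idx1 and idx2
    rows = torch.arange(N, device=x.device).unsqueeze(1)
    # All three index tensors broadcast to (N, K), so the result is (N, K)
    return x[rows, idx1, idx2]
```

Because no N x N x K intermediate is built, this should scale like the loop in memory while running as a single indexing kernel.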
It seems that torch.gather is not directly applicable here. If we instead compute z = x[:, idx1, idx2], z has shape (N, N, K) and we still need y[i, j] = z[i, i, j], which doesn't seem any easier.
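If torch.gather is preferred, one workaround (a sketch, not the only option) is to flatten the last two dimensions of x so that each (b, v) pair becomes the single index b * V + v; gathering along dim 1 then picks exactly the requested elements. The helper name f_gather is hypothetical, and idx1 and idx2 are assumed to be int64, as torch.gather requires for its index.

```python
import torch

def f_gather(x, idx1, idx2):
    """Vectorized lookup using torch.gather on a flattened view of x.

    Hypothetical helper; assumes x is (N, B, V) and idx1, idx2 are
    (N, K) int64 tensors with values in [0, B) and [0, V) respectively.
    """
    N, B, V = x.shape
    # Row-major flattening: element (b, v) of x[i] lands at position b * V + v
    flat_idx = idx1 * V + idx2          # (N, K)
    x_flat = x.reshape(N, B * V)        # (N, B*V)
    # gather along dim 1: out[i, j] = x_flat[i, flat_idx[i, j]]
    return torch.gather(x_flat, 1, flat_idx)
```

For the z = x[:, idx1, idx2] route, the needed entries z[i, i, j] form a diagonal over the first two dimensions, so torch.diagonal(z, dim1=0, dim2=1).T would also give y, although it materializes an N x N x K intermediate.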
For convenience, the function above can be tested with the following:
def main():
    N = 2
    B = 3
    V = 10
    K = B*2
    x = torch.randn(N, B, V)
    idx1 = torch.randint(0, B, size=(N, K))
    idx2 = torch.randint(0, V, size=(N, K))
    y = f(x, idx1, idx2)
    print(y)

if __name__ == "__main__":
    main()
Thank you in advance for your help!