Similar to torch.gather over two dimensions

Hello,

I would like to ask for your help on making a vectorized version of the following simple operation.

I have a 3D tensor x of shape (N, B, V) and I would like to get the elements of x at indices given by two (N, K) tensors idx1 and idx2 as follows: y[i, j] = x[i, idx1[i,j], idx2[i,j]].

Using a for loop, this can be done using the following function:

def f(x, idx1, idx2):
    """Compute y[i, j] = x[i, idx1[i, j], idx2[i, j]] with an explicit loop.

    Slow reference implementation, handy for checking vectorized versions.

    Args:
        x: N x B x V tensor.
        idx1: N x K integer tensor with idx1[i, j] in [0, B).
        idx2: N x K integer tensor with idx2[i, j] in [0, V).

    Returns:
        N x K tensor y (same dtype and device as x) where
        y[i, j] = x[i, idx1[i, j], idx2[i, j]].
    """
    N, K = idx1.shape
    # Allocate with x's dtype/device: torch.zeros(idx1.shape) would always be
    # float32 on the CPU, silently casting float64/integer values.
    y = x.new_zeros(idx1.shape)
    for i in range(N):
        for j in range(K):
            y[i, j] = x[i, idx1[i, j], idx2[i, j]]
    return y

It seems that torch.gather is not directly applicable here, since it indexes along a single dimension only. If we do z = x[:,idx1,idx2] (a tensor of shape (N, N, K)), then it remains to compute y[i,j] = z[i,i,j], which doesn’t seem to be any easier.

For your convenience the above function can be tested using the following:

def main(N=2, B=3, V=10, K=None):
    """Build a small random example and evaluate f on it.

    Args:
        N, B, V: shape of the random input x (N x B x V).
        K: number of indices per row; defaults to 2 * B.

    Returns:
        (x, idx1, idx2, y) so the caller can inspect the result or compare
        it against a vectorized implementation.
    """
    if K is None:
        K = B * 2
    x = torch.randn(N, B, V)
    idx1 = torch.randint(0, B, size=(N, K))
    idx2 = torch.randint(0, V, size=(N, K))
    y = f(x, idx1, idx2)
    return x, idx1, idx2, y

Thank you in advance for your help!

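I have finally found a solution, but I’m not sure it’s the most efficient one, since it first builds the full (N, N, K) tensor x[:,idx1,idx2] and then keeps only its diagonal z[i,i,j]: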

# x[:, idx1, idx2] has shape (N, N, K) with z[a, i, j] = x[a, idx1[i, j], idx2[i, j]];
# 'iij->ij' keeps only its diagonal over the first two dims, i.e. z[i, i, j].
y = torch.einsum('iij->ij', x[:,idx1,idx2])

Hi,

Alternatively, you can linearize (flatten) the last two dimensions into one, so that element (b, v) ends up at linear index b * V + v, and then use gather like this:

import torch

def f(x, idx1, idx2):
    """Compute y[i, j] = x[i, idx1[i, j], idx2[i, j]] with an explicit loop.

    Slow reference implementation, handy for checking vectorized versions.

    Args:
        x: N x B x V tensor.
        idx1: N x K integer tensor with idx1[i, j] in [0, B).
        idx2: N x K integer tensor with idx2[i, j] in [0, V).

    Returns:
        N x K tensor y (same dtype and device as x) where
        y[i, j] = x[i, idx1[i, j], idx2[i, j]].
    """
    N, K = idx1.shape
    # Allocate with x's dtype/device: torch.zeros(idx1.shape) would always be
    # float32 on the CPU, silently casting float64/integer values.
    y = x.new_zeros(idx1.shape)
    for i in range(N):
        for j in range(K):
            y[i, j] = x[i, idx1[i, j], idx2[i, j]]
    return y

def f2(x, idx1, idx2):
    """Vectorized y[i, j] = x[i, idx1[i, j], idx2[i, j]] using one gather.

    The last two dimensions of x are flattened, so element (b, v) of x[i]
    sits at linear position b * V + v; a single gather along that flattened
    dimension then picks all N x K elements.

    Args:
        x: N x B x V tensor (need not be contiguous).
        idx1: N x K int64 tensor with idx1[i, j] in [0, B).
        idx2: N x K int64 tensor with idx2[i, j] in [0, V).

    Returns:
        N x K tensor with the same dtype and device as x.

    Note:
        The two indices are not range-checked separately. An idx2 outside
        [0, V) (including negative values) can still produce an in-range
        linear index and silently read from a neighbouring row of x[i],
        whereas the loop version would index x[i, b, v] directly.
    """
    V = x.size(-1)
    lin_idx = idx1 * V + idx2
    # flatten(1) yields an (N, B*V) tensor, copying only when x is not
    # contiguous; unlike view(-1, B*V) it also works when B*V == 0.
    return x.flatten(1).gather(-1, lin_idx)

N, B, V = 2, 3, 10
K = 2 * B
x = torch.randn(N, B, V)
idx1 = torch.randint(0, B, size=(N, K))
idx2 = torch.randint(0, V, size=(N, K))

# Compare the loop reference against the flattened-gather version;
# the largest absolute difference is expected to be zero.
y = f(x, idx1, idx2)
y2 = f2(x, idx1, idx2)
max_abs_diff = (y - y2).abs().max()
print(max_abs_diff)
3 Likes

Wow, that was fast! Thank you so much, @albanD! Let me try your solution to see how fast it is.

1 Like

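@albanD Unsurprisingly, your version is super fast! Using the benchmark code below, I obtained the following timings (in seconds) on my Mac (CPU), where f1 is the for loop, f2 is the einsum version and f3 is your version: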

10 trials took:
f1: 2.2150492668151855
f2: 0.36966371536254883
f3: 0.004826068878173828

Thank you so much again!

import time
import torch

def f1(x, idx1, idx2):
    """Compute y[i, j] = x[i, idx1[i, j], idx2[i, j]] with an explicit loop.

    Slow reference implementation used as the baseline in the benchmark.

    Args:
        x: N x B x V tensor.
        idx1: N x K integer tensor with idx1[i, j] in [0, B).
        idx2: N x K integer tensor with idx2[i, j] in [0, V).

    Returns:
        N x K tensor y (same dtype and device as x) where
        y[i, j] = x[i, idx1[i, j], idx2[i, j]].
    """
    N, K = idx1.shape
    # Allocate with x's dtype/device: torch.zeros(idx1.shape) would always be
    # float32 on the CPU, silently casting float64/integer values.
    y = x.new_zeros(idx1.shape)
    for i in range(N):
        for j in range(K):
            y[i, j] = x[i, idx1[i, j], idx2[i, j]]
    return y

def f2(x, idx1, idx2):
    """Einsum-based vectorized version of f1.

    Advanced indexing x[:, idx1, idx2] yields an N x N x K tensor z with
    z[a, i, j] = x[a, idx1[i, j], idx2[i, j]]. The wanted result is its
    diagonal over the first two dimensions, y[i, j] = z[i, i, j]. Note that
    the whole N x N x K intermediate is materialized first.
    """
    candidates = x[:, idx1, idx2]
    # The repeated subscript 'ii' keeps only entries whose first two indices agree.
    return torch.einsum('iij->ij', candidates)


def f3(x, idx1, idx2):
    """Vectorized y[i, j] = x[i, idx1[i, j], idx2[i, j]] using one gather.

    https://discuss.pytorch.org/t/similar-to-torch-gather-over-two-dimensions/118827/3?u=f10w

    The last two dimensions of x are flattened, so element (b, v) of x[i]
    sits at linear position b * V + v; a single gather along that flattened
    dimension then picks all N x K elements.

    Args:
        x: N x B x V tensor (need not be contiguous).
        idx1: N x K int64 tensor with idx1[i, j] in [0, B).
        idx2: N x K int64 tensor with idx2[i, j] in [0, V).

    Returns:
        N x K tensor with the same dtype and device as x.

    Note:
        The two indices are not range-checked separately. An idx2 outside
        [0, V) (including negative values) can still produce an in-range
        linear index and silently read from a neighbouring row of x[i],
        whereas f1 would index x[i, b, v] directly.
    """
    V = x.size(-1)
    lin_idx = idx1 * V + idx2
    # flatten(1) yields an (N, B*V) tensor, copying only when x is not
    # contiguous; unlike view(-1, B*V) it also works when B*V == 0.
    return x.flatten(1).gather(-1, lin_idx)


def main(N=100, B=100, V=10000, K=None, repeat=10):
    """Benchmark f1 (loop), f2 (einsum) and f3 (flattened gather).

    Each trial draws fresh random inputs, times the three implementations
    and checks that they produce identical results. Total wall time per
    implementation (in seconds) is printed at the end.

    Args:
        N, B, V: shape of the random input x (N x B x V).
        K: number of indices per row; defaults to 2 * B.
        repeat: number of trials.
    """
    if K is None:
        K = B * 2

    t1 = t2 = t3 = 0.0

    for _ in range(repeat):
        x = torch.randn(N, B, V)
        idx1 = torch.randint(0, B, size=(N, K))
        idx2 = torch.randint(0, V, size=(N, K))

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        start = time.perf_counter()
        y1 = f1(x, idx1, idx2)
        t1 += time.perf_counter() - start

        start = time.perf_counter()
        y2 = f2(x, idx1, idx2)
        t2 += time.perf_counter() - start

        start = time.perf_counter()
        y3 = f3(x, idx1, idx2)
        t3 += time.perf_counter() - start

        # All versions copy the same elements of x, so they must agree
        # exactly; torch.equal also catches shape mismatches that a
        # broadcasting difference would hide.
        assert torch.equal(y1, y2)
        assert torch.equal(y1, y3)

    print(f'{repeat} trials took:\nf1: {t1}\nf2: {t2}\nf3: {t3}')


if __name__ == "__main__":
    main()
2 Likes